merge with main and resolve conflicts

This commit is contained in:
Lincoln Stein 2024-05-27 22:20:34 -04:00
commit 34e1eb19f9
256 changed files with 9360 additions and 6061 deletions

View File

@ -64,7 +64,7 @@ GPU_DRIVER=nvidia
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail. Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even Moar Customizing! ## Even More Customizing!
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below. See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.

View File

@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
"Custom" fields will always be treated as stateless fields. "Custom" fields will always be treated as stateless fields.
##### Collection and Scalar Fields ##### Single and Collection Fields
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field. Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field.
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`). - If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`).
- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`).
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). - If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
## Implementation ## Implementation
@ -173,8 +173,7 @@ Field types are represented as structured objects:
```ts ```ts
type FieldType = { type FieldType = {
name: string; name: string;
isCollection: boolean; cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
isCollectionOrScalar: boolean;
}; };
``` ```
@ -186,7 +185,7 @@ There are 4 general cases for field type parsing.
When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property. When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property.
We create a field type name from this `type` string (e.g. `string` -> `StringField`). We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`.
##### Complex Types ##### Complex Types
@ -200,13 +199,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi
When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type. When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type.
We use the item type for field type name, adding `isCollection: true` to the field type. We use the item type for field type name. The cardinality is `"COLLECTION"`.
##### Collection or Scalar Types ##### Single or Collection Types
When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union. When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type. After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`.
##### Optional Fields ##### Optional Fields

View File

@ -165,7 +165,7 @@ Additionally, each section can be expanded with the "Show Advanced" button in o
There are several ways to install IP-Adapter models with an existing InvokeAI installation: There are several ways to install IP-Adapter models with an existing InvokeAI installation:
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models. 1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models.
2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](https://www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models. 2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder. 3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
#### Using IP-Adapter #### Using IP-Adapter

View File

@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
4. The VAE decodes the final latent image from latent space into image space. 4. The VAE decodes the final latent image from latent space into image space.
Image-to-image is a similar process, with only step 1 being different: Image-to-image is a similar process, with only step 1 being different:
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process. 1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor). Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).

View File

@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The
### Requirements ### Requirements
Before you start, go through the [installation requirements]. Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md).
### Installation Walkthrough ### Installation Walkthrough
@ -79,7 +79,7 @@ Before you start, go through the [installation requirements].
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features. 1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command. - You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
!!! example "Install with an extra index URL" !!! example "Install with an extra index URL"
@ -116,4 +116,4 @@ Before you start, go through the [installation requirements].
!!! warning !!! warning
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.

View File

@ -10,8 +10,7 @@ set INVOKEAI_ROOT=.
echo Desired action: echo Desired action:
echo 1. Generate images with the browser-based interface echo 1. Generate images with the browser-based interface
echo 2. Open the developer console echo 2. Open the developer console
echo 3. Run the InvokeAI image database maintenance script echo 3. Command-line help
echo 4. Command-line help
echo Q - Quit echo Q - Quit
echo. echo.
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest. echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.
@ -34,9 +33,6 @@ IF /I "%choice%" == "1" (
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment *** echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k call cmd /k
) ELSE IF /I "%choice%" == "3" ( ) ELSE IF /I "%choice%" == "3" (
echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "4" (
echo Displaying command line help... echo Displaying command line help...
python .venv\Scripts\invokeai-web.exe --help %* python .venv\Scripts\invokeai-web.exe --help %*
pause pause

View File

@ -47,11 +47,6 @@ do_choice() {
bash --init-file "$file_name" bash --init-file "$file_name"
;; ;;
3) 3)
clear
printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
;;
4)
clear clear
printf "Command-line help\n" printf "Command-line help\n"
invokeai-web --help invokeai-web --help
@ -71,8 +66,7 @@ do_line_input() {
printf "What would you like to do?\n" printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n" printf "1: Generate images using the browser-based interface\n"
printf "2: Open the developer console\n" printf "2: Open the developer console\n"
printf "3: Run the InvokeAI image database maintenance script\n" printf "3: Command-line help\n"
printf "4: Command-line help\n"
printf "Q: Quit\n\n" printf "Q: Quit\n\n"
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n" printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n"
read -p "Please enter 1-4, Q: [1] " yn read -p "Please enter 1-4, Q: [1] " yn

View File

@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService from ..services.images.images_default import ImageService
@ -29,11 +30,10 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this? # TODO: is there a better way to achieve this?
@ -103,7 +103,7 @@ class ApiDependencies:
) )
names = SimpleNameService() names = SimpleNameService()
performance_statistics = InvocationStatsService() performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor() session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db) session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService() urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db) workflow_records = SqliteWorkflowRecordsStorage(db=db)

View File

@ -1,52 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -6,7 +6,7 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, JsonValue
from invokeai.app.invocations.fields import MetadataField from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
@ -41,14 +41,17 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"), board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"), session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"), crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
metadata: Optional[JsonValue] = Body(
default=None, description="The metadata to associate with the image", embed=True
),
) -> ImageDTO: ) -> ImageDTO:
"""Uploads an image""" """Uploads an image"""
if not file.content_type or not file.content_type.startswith("image"): if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image") raise HTTPException(status_code=415, detail="Not an image")
metadata = None _metadata = None
workflow = None _workflow = None
graph = None _graph = None
contents = await file.read() contents = await file.read()
try: try:
@ -62,27 +65,27 @@ async def upload_image(
# TODO: retain non-invokeai metadata on upload? # TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image # attempt to parse metadata from image
metadata_raw = pil_image.info.get("invokeai_metadata", None) metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
if isinstance(metadata_raw, str): if isinstance(metadata_raw, str):
metadata = metadata_raw _metadata = metadata_raw
else: else:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
pass pass
# attempt to parse workflow from image # attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None) workflow_raw = pil_image.info.get("invokeai_workflow", None)
if isinstance(workflow_raw, str): if isinstance(workflow_raw, str):
workflow = workflow_raw _workflow = workflow_raw
else: else:
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image") ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
pass pass
# attempt to extract graph from image # attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None) graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str): if isinstance(graph_raw, str):
graph = graph_raw _graph = graph_raw
else: else:
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image") ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
pass pass
try: try:
@ -92,9 +95,9 @@ async def upload_image(
image_category=image_category, image_category=image_category,
session_id=session_id, session_id=session_id,
board_id=board_id, board_id=board_id,
metadata=metadata, metadata=_metadata,
workflow=workflow, workflow=_workflow,
graph=graph, graph=_graph,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
) )

View File

@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,

View File

@ -203,6 +203,7 @@ async def get_batch_status(
responses={ responses={
200: {"model": SessionQueueItem}, 200: {"model": SessionQueueItem},
}, },
response_model_exclude_none=True,
) )
async def get_queue_item( async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"), queue_id: str = Path(description="The queue id to perform this operation on"),

View File

@ -1,66 +1,125 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler from pydantic import BaseModel
from fastapi_events.typing import Event
from socketio import ASGIApp, AsyncServer from socketio import ASGIApp, AsyncServer
from ..services.events.events_base import EventServiceBase from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadEventBase,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io queue room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io bulk downloads room.
This is a pydantic model to ensure the data is in the correct format."""
bulk_download_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
}
MODEL_EVENTS = {
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
class SocketIO: class SocketIO:
__sio: AsyncServer _sub_queue = "subscribe_queue"
__app: ASGIApp _unsub_queue = "unsubscribe_queue"
__sub_queue: str = "subscribe_queue" _sub_bulk_download = "subscribe_bulk_download"
__unsub_queue: str = "unsubscribe_queue" _unsub_bulk_download = "unsubscribe_bulk_download"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app) app.mount("/ws", self._app)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) register_events(QUEUE_EVENTS, self._handle_queue_event)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) register_events(MODEL_EVENTS, self._handle_model_event)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_queue_event(self, event: Event): async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self.__sio.emit( await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
if "queue_id" in data: await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
if "queue_id" in data: await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None: async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_bulk_download_event(self, event: Event): async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self.__sio.emit( await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
if "bulk_download_id" in data: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
if "bulk_download_id" in data: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@ -27,6 +27,7 @@ import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -182,23 +183,14 @@ def custom_openapi() -> dict[str, Any]:
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation" invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary? # Add all event schemas
# Leave it here just in case for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
# json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
# from invokeai.backend.model_manager import get_model_config_formats if "$defs" in json_schema:
# formats = get_model_config_formats() for schema_key, schema in json_schema["$defs"].items():
# for model_config_name, enum_set in formats.items(): openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
# if model_config_name in openapi_schema["components"]["schemas"]: openapi_schema["components"]["schemas"][event.__name__] = json_schema
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema

View File

@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer) tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder) text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras: for lora in self.clip.loras:
@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation):
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( # apply all patches while the model is on the target device
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder, text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
): ):
assert isinstance(text_encoder, CLIPTextModel) assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=patched_tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype, dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt) conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization: if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer) log_tokenization_for_conjunction(conjunction, patched_tokenizer)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -136,11 +134,7 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder) text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty # return zero on empty
if prompt == "" and zero_on_empty: if prompt == "" and zero_on_empty:
@ -177,20 +171,23 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( # apply all patches while the model is on the target device
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder, text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
): ):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)) assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)
text_encoder = cast(CLIPTextModel, text_encoder) text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=patched_tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype, dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -203,7 +200,7 @@ class SDXLPromptInvocationBase:
if context.config.get().log_tokenization: if context.config.get().log_tokenization:
# TODO: better logging for and syntax # TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, tokenizer) log_tokenization_for_conjunction(conjunction, patched_tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice # TODO: ask for optimizations? to not run text_encoder twice
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

View File

@ -25,7 +25,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
Input,
InputField, InputField,
OutputField, OutputField,
UIType, UIType,
@ -82,13 +81,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control) control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1") @invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
class ControlNetInvocation(BaseInvocation): class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField( control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
) )
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_model: ModelIdentifierField = InputField( ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", description="The IP-Adapter model.",
title="IP-Adapter Model", title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1, ui_order=-1,
ui_type=UIType.IPAdapterModel, ui_type=UIType.IPAdapterModel,
) )

View File

@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet, unet_info as unet,
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ModelPatcher.apply_lora_unet(unet, _lora_loader()),
): ):

View File

@ -11,6 +11,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
Classification,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -93,19 +94,46 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass pass
@invocation_output("model_identifier_output")
class ModelIdentifierOutput(BaseInvocationOutput):
"""Model identifier output"""
model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")
@invocation(
"model_identifier",
title="Model identifier",
tags=["model"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
error."""
model: ModelIdentifierField = InputField(description="The model to select", title="Model")
def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
return ModelIdentifierOutput(model=self.model)
@invocation( @invocation(
"main_model_loader", "main_model_loader",
title="Main Model", title="Main Model",
tags=["model"], tags=["model"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision? # TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@ -134,12 +162,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2") @invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
class LoRALoaderInvocation(BaseInvocation): class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
@ -197,12 +225,12 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA") lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0") @invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
class LoRASelectorInvocation(BaseInvocation): class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight.""" """Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@ -273,13 +301,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA", title="SDXL LoRA",
tags=["lora", "model"], tags=["lora", "model"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class SDXLLoRALoaderInvocation(BaseInvocation): class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
@ -414,12 +442,12 @@ class SDXLLoRACollectionLoader(BaseInvocation):
return output return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
class VAELoaderInvocation(BaseInvocation): class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField( vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
) )
def invoke(self, context: InvocationContext) -> VAEOutput: def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@ -1,4 +1,4 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager import SubModelType
@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2") @invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
) )
# TODO: precision? # TODO: precision?
@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model", title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"], tags=["model", "sdxl", "refiner"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
) )
# TODO: precision? # TODO: precision?

View File

@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation( @invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2" "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
) )
class T2IAdapterInvocation(BaseInvocation): class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes.""" """Collects T2I-Adapter info to pass to other nodes."""
@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation):
t2i_adapter_model: ModelIdentifierField = InputField( t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.", description="The T2I-Adapter model.",
title="T2I-Adapter Model", title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1, ui_order=-1,
ui_type=UIType.T2IAdapterModel, ui_type=UIType.T2IAdapterModel,
) )

View File

@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started( self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
) )
def _signal_job_completed( def _signal_job_completed(
@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
assert bulk_download_item_name is not None assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed( self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
) )
def _signal_job_failed( def _signal_job_failed(
@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
assert exception is not None assert exception is not None
self._invoker.services.events.emit_bulk_download_failed( self._invoker.services.events.emit_bulk_download_error(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
) )
def stop(self, *args, **kwargs): def stop(self, *args, **kwargs):

View File

@ -8,7 +8,7 @@ import time
import traceback import traceback
from pathlib import Path from pathlib import Path
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional, Set from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
import requests import requests
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
@ -34,6 +34,9 @@ from .download_base import (
UnknownJobIDException, UnknownJobIDException,
) )
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content() # Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000 DOWNLOAD_CHUNK_SIZE = 100000
@ -45,7 +48,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
self, self,
max_parallel_dl: int = 5, max_parallel_dl: int = 5,
app_config: Optional[InvokeAIAppConfig] = None, app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None, requests_session: Optional[requests.sessions.Session] = None,
): ):
""" """
@ -408,28 +411,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.status = DownloadJobStatus.RUNNING job.status = DownloadJobStatus.RUNNING
self._execute_cb(job, "on_start") self._execute_cb(job, "on_start")
if self._event_bus: if self._event_bus:
assert job.download_path self._event_bus.emit_download_started(job)
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None: def _signal_job_progress(self, job: DownloadJob) -> None:
self._execute_cb(job, "on_progress") self._execute_cb(job, "on_progress")
if self._event_bus: if self._event_bus:
assert job.download_path self._event_bus.emit_download_progress(job)
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None: def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED job.status = DownloadJobStatus.COMPLETED
self._execute_cb(job, "on_complete") self._execute_cb(job, "on_complete")
if self._event_bus: if self._event_bus:
assert job.download_path self._event_bus.emit_download_complete(job)
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None: def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
@ -437,7 +430,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.status = DownloadJobStatus.CANCELLED job.status = DownloadJobStatus.CANCELLED
self._execute_cb(job, "on_cancelled") self._execute_cb(job, "on_cancelled")
if self._event_bus: if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source)) self._event_bus.emit_download_cancelled(job)
# if multifile download, then signal the parent # if multifile download, then signal the parent
if parent_job := self._download_part2parent.get(job.source, None): if parent_job := self._download_part2parent.get(job.source, None):
@ -451,9 +444,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._execute_cb(job, "on_error", excp) self._execute_cb(job, "on_error", excp)
if self._event_bus: if self._event_bus:
assert job.error_type self._event_bus.emit_download_error(job)
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None: def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}") self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")

View File

@ -1,490 +1,195 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Optional
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.events.events_common import (
from invokeai.app.services.session_queue.session_queue_common import ( BatchEnqueuedEvent,
BatchStatus, BulkDownloadCompleteEvent,
EnqueueBatchResult, BulkDownloadErrorEvent,
SessionQueueItem, BulkDownloadStartedEvent,
SessionQueueStatus, DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
) )
from invokeai.app.util.misc import get_timestamp from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
class EventServiceBase: class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed""" """Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event: "EventBase") -> None:
pass pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: # region: Invocation
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None: def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
"""Queue events are emitted to a room with queue_id as the room name""" """Emitted when an invocation is started"""
payload["timestamp"] = get_timestamp() self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None: def emit_invocation_denoise_progress(
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self, self,
queue_id: str, queue_item: "SessionQueueItem",
queue_item_id: int, invocation: "BaseInvocation",
queue_batch_id: str, intermediate_state: PipelineIntermediateState,
graph_execution_state_id: str, progress_image: "ProgressImage",
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
) -> None: ) -> None:
"""Emitted when there is generation progress""" """Emitted at each step during denoising of an invocation."""
self.__emit_queue_event( self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete( def emit_invocation_complete(
self, self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None: ) -> None:
"""Emitted when an invocation has completed""" """Emitted when an invocation is complete"""
self.__emit_queue_event( self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error( def emit_invocation_error(
self, self,
queue_id: str, queue_item: "SessionQueueItem",
queue_item_id: int, invocation: "BaseInvocation",
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str, error_type: str,
error: str, error_message: str,
user_id: str | None, error_traceback: str,
project_id: str | None,
) -> None: ) -> None:
"""Emitted when an invocation has completed""" """Emitted when an invocation encounters an error"""
self.__emit_queue_event( self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
"user_id": user_id,
"project_id": project_id,
},
)
def emit_invocation_started( # endregion
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
def emit_graph_execution_complete( # region Queue
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_queue_item_status_changed( def emit_queue_item_status_changed(
self, self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None: ) -> None:
"""Emitted when a queue item's status changes""" """Emitted when a queue item's status changes"""
self.__emit_queue_event( self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
"""Emitted when a batch is enqueued""" """Emitted when a batch is enqueued"""
self.__emit_queue_event( self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_queue_cleared(self, queue_id: str) -> None: def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared""" """Emitted when a queue is cleared"""
self.__emit_queue_event( self.dispatch(QueueClearedEvent.build(queue_id))
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_download_started(self, source: str, download_path: str) -> None: # endregion
"""
Emit when a download job is started.
:param url: The downloaded url # region Download
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: def emit_download_started(self, job: "DownloadJob") -> None:
""" """Emitted when a download is started"""
Emit "download_progress" events at regular intervals during a download job. self.dispatch(DownloadStartedEvent.build(job))
:param source: The downloaded source def emit_download_progress(self, job: "DownloadJob") -> None:
:param download_path: The local downloaded file """Emitted at intervals during a download"""
:param current_bytes: Number of bytes downloaded so far self.dispatch(DownloadProgressEvent.build(job))
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: def emit_download_complete(self, job: "DownloadJob") -> None:
""" """Emitted when a download is completed"""
Emit a "download_complete" event at the end of a successful download. self.dispatch(DownloadCompleteEvent.build(job))
:param source: Source URL def emit_download_cancelled(self, job: "DownloadJob") -> None:
:param download_path: Path to the locally downloaded file """Emitted when a download is cancelled"""
:param total_bytes: The size of the downloaded file self.dispatch(DownloadCancelledEvent.build(job))
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, source: str) -> None: def emit_download_error(self, job: "DownloadJob") -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user.""" """Emitted when a download encounters an error"""
self.__emit_download_event( self.dispatch(DownloadErrorEvent.build(job))
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, source: str, error_type: str, error: str) -> None: # endregion
"""
Emit a "download_error" event when an download job encounters an exception.
:param source: Source URL # region Model loading
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_install_downloading( def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
self, """Emitted when a model load is started."""
source: str, self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
local_path: str,
bytes: int, def emit_model_load_complete(
total_bytes: int, self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None: ) -> None:
""" """Emitted when a model load is complete."""
Emit at intervals while the install job is in progress (remote models only). self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
:param source: Source of the model # endregion
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
def emit_model_install_downloads_done(self, source: str) -> None: # region Model install
"""
Emit once when all parts are downloaded, but before the probing and registration start.
:param source: Source of the model; local path, repo_id or url def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
""" """Emitted at intervals while the install job is in progress (remote models only)."""
self.__emit_model_event( self.dispatch(ModelInstallDownloadProgressEvent.build(job))
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_running(self, source: str) -> None: def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
""" self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
Emit once when an install job becomes active.
:param source: Source of the model; local path, repo_id or url def emit_model_install_started(self, job: "ModelInstallJob") -> None:
""" """Emitted once when an install job is started (after any download)."""
self.__emit_model_event( self.dispatch(ModelInstallStartedEvent.build(job))
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
""" """Emitted when an install job is completed successfully."""
Emit when an install job is completed successfully. self.dispatch(ModelInstallCompleteEvent.build(job))
:param source: Source of the model; local path, repo_id or url def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
:param key: Model config record key """Emitted when an install job is cancelled."""
:param total_bytes: Size of the model (may be None for installation of a local path) self.dispatch(ModelInstallCancelledEvent.build(job))
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, source: str, id: int) -> None: def emit_model_install_error(self, job: "ModelInstallJob") -> None:
""" """Emitted when an install job encounters an exception."""
Emit when an install job is cancelled. self.dispatch(ModelInstallErrorEvent.build(job))
:param source: Source of the model; local path, repo_id or url # endregion
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: # region Bulk image download
"""
Emit when an install job encounters an exception.
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started( def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None: ) -> None:
"""Emitted when a bulk download starts""" """Emitted when a bulk image download is started"""
self._emit_bulk_download_event( self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_completed( def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None: ) -> None:
"""Emitted when a bulk download completes""" """Emitted when a bulk image download is complete"""
self._emit_bulk_download_event( self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed( def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None: ) -> None:
"""Emitted when a bulk download fails""" """Emitted when a bulk image download has an error"""
self._emit_bulk_download_event( self.dispatch(
event_name="bulk_download_failed", BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
) )
# endregion

View File

@ -0,0 +1,591 @@
from math import floor
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol, Generic[TEvent]):
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
"""Register a function to handle specific events.
:param events: An event or set of events to handle
:param func: The function to handle the events
"""
events = events if isinstance(events, set) else {events}
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
class QueueEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class InvocationEventBase(QueueItemEventBase):
"""Base class for invocation events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
)
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
__event_name__ = "invocation_complete"
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
result=result,
)
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The error type")
error_message: str = Field(description="The error message")
error_traceback: str = Field(description="The error traceback")
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
error_type: str,
error_message: str,
error_traceback: str,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
@classmethod
def build(
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
)
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
source: str = Field(description="The source of the download")
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
)

View File

@ -0,0 +1,47 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
EventBase,
)
from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -1,11 +1,13 @@
"""Initialization file for model install service package.""" """Initialization file for model install service package."""
from .model_install_base import ( from .model_install_base import (
ModelInstallServiceBase,
)
from .model_install_common import (
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase,
ModelSource, ModelSource,
UnknownInstallJobException, UnknownInstallJobException,
URLModelSource, URLModelSource,

View File

@ -1,242 +1,17 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team # Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer.""" """Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
class ModelInstallServiceBase(ABC): class ModelInstallServiceBase(ABC):
@ -280,7 +55,7 @@ class ModelInstallServiceBase(ABC):
@property @property
@abstractmethod @abstractmethod
def event_bus(self) -> Optional[EventServiceBase]: def event_bus(self) -> Optional["EventServiceBase"]:
"""Return the event service base object associated with the installer.""" """Return the event service base object associated with the installer."""
@abstractmethod @abstractmethod

View File

@ -0,0 +1,227 @@
import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]

View File

@ -9,7 +9,7 @@ from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch import torch
import yaml import yaml
@ -21,6 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -45,13 +46,12 @@ from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify from invokeai.backend.util.util import slugify
from .model_install_base import ( from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP, MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase,
ModelSource, ModelSource,
StringLikeSource, StringLikeSource,
URLModelSource, URLModelSource,
@ -59,6 +59,9 @@ from .model_install_base import (
TMPDIR_PREFIX = "tmpinstall_" TMPDIR_PREFIX = "tmpinstall_"
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallService(ModelInstallServiceBase): class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation.""" """class for InvokeAI model installation."""
@ -68,7 +71,7 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional["EventServiceBase"] = None,
session: Optional[Session] = None, session: Optional[Session] = None,
): ):
""" """
@ -104,7 +107,7 @@ class ModelInstallService(ModelInstallServiceBase):
return self._record_store return self._record_store
@property @property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
return self._event_bus return self._event_bus
# make the invoker optional here because we don't need it and it # make the invoker optional here because we don't need it and it
@ -825,6 +828,7 @@ class ModelInstallService(ModelInstallServiceBase):
else: else:
# update sizes # update sizes
install_job.bytes = sum(x.bytes for x in download_job.download_parts) install_job.bytes = sum(x.bytes for x in download_job.download_parts)
install_job.download_parts = download_job.download_parts
self._signal_job_downloading(install_job) self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
@ -864,36 +868,20 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.RUNNING job.status = InstallStatus.RUNNING
self._logger.info(f"Model install started: {job.source}") self._logger.info(f"Model install started: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_running(str(job.source)) self._event_bus.emit_model_install_started(job)
def _signal_job_downloading(self, job: ModelInstallJob) -> None: def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus: if self._event_bus:
assert job._download_job is not None assert job._download_job is not None
parts: List[Dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job._download_job.download_parts
]
assert job.bytes is not None assert job.bytes is not None
assert job.total_bytes is not None assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading( self._event_bus.emit_model_install_download_progress(job)
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=sum(x["bytes"] for x in parts),
total_bytes=sum(x["total_bytes"] for x in parts),
id=job.id,
)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.DOWNLOADS_DONE job.status = InstallStatus.DOWNLOADS_DONE
self._logger.info(f"Model download complete: {job.source}") self._logger.info(f"Model download complete: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_downloads_done(str(job.source)) self._event_bus.emit_model_install_downloads_complete(job)
def _signal_job_completed(self, job: ModelInstallJob) -> None: def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED job.status = InstallStatus.COMPLETED
@ -903,24 +891,19 @@ class ModelInstallService(ModelInstallServiceBase):
if self._event_bus: if self._event_bus:
assert job.local_path is not None assert job.local_path is not None
assert job.config_out is not None assert job.config_out is not None
key = job.config_out.key self._event_bus.emit_model_install_complete(job)
self._event_bus.emit_model_install_completed(
source=str(job.source), key=key, id=job.id, total_bytes=job.bytes
)
def _signal_job_errored(self, job: ModelInstallJob) -> None: def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
if self._event_bus: if self._event_bus:
error_type = job.error_type assert job.error_type is not None
error = job.error assert job.error is not None
assert error_type is not None self._event_bus.emit_model_install_error(job)
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None: def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"Model install canceled: {job.source}") self._logger.info(f"Model install canceled: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) self._event_bus.emit_model_install_cancelled(job)
@staticmethod @staticmethod
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]: def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:

View File

@ -7,7 +7,6 @@ from typing import Callable, Dict, Optional
from torch import Tensor from torch import Tensor
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@ -18,18 +17,12 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader.""" """Wrapper around AnyModelLoader."""
@abstractmethod @abstractmethod
def load_model( def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
""" """
@property @property

View File

@ -11,7 +11,6 @@ from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import ( from invokeai.backend.model_manager.load import (
LoadedModel, LoadedModel,
@ -59,25 +58,18 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the checkpoint convert cache used by this loader.""" """Return the checkpoint convert cache used by this loader."""
return self._convert_cache return self._convert_cache
def load_model( def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
""" """
if context_data:
self._emit_load_event( # We don't have an invoker during testing
context_data=context_data, # TODO(psyche): Mock this method on the invoker in the tests
model_config=model_config, if hasattr(self, "_invoker"):
submodel_type=submodel_type, self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation( loaded_model: LoadedModel = implementation(
@ -87,13 +79,9 @@ class ModelLoadService(ModelLoadServiceBase):
convert_cache=self._convert_cache, convert_cache=self._convert_cache,
).load_model(model_config, submodel_type) ).load_model(model_config, submodel_type)
if context_data: if hasattr(self, "_invoker"):
self._emit_load_event( self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
return loaded_model return loaded_model
def load_model_from_path( def load_model_from_path(
@ -150,32 +138,3 @@ class ModelLoadService(ModelLoadServiceBase):
raw_model = loader(model_path) raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model) ram_cache.put(key=cache_key, model=raw_model)
return LoadedModel(_locker=ram_cache.get(key=cache_key)) return LoadedModel(_locker=ram_cache.get(key=cache_key))
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)

View File

@ -1,6 +1,49 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
"""Starts the session runner.
Args:
services: The invocation services.
cancel_event: The cancel event.
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
stats will be still be recorded and logged when profiling is disabled.
"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs a session.
Args:
queue_item: The session to run.
"""
pass
@abstractmethod
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Run a single node in the graph.
Args:
invocation: The invocation to run.
queue_item: The session queue item.
"""
pass
class SessionProcessorBase(ABC): class SessionProcessorBase(ABC):
@ -26,3 +69,85 @@ class SessionProcessorBase(ABC):
def get_status(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor""" """Gets the status of the session processor"""
pass pass
class OnBeforeRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that will be executed.
queue_item: The session queue item.
"""
...
class OnAfterRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that was executed.
queue_item: The session queue item.
"""
...
class OnNodeError(Protocol):
def __call__(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a node has an error.
Args:
invocation: The invocation that errored.
queue_item: The session queue item.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...
class OnBeforeRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnAfterRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run after executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnNonFatalProcessorError(Protocol):
def __call__(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a non-fatal error occurs in the processor.
Args:
queue_item: The session queue item, if one was being executed when the error occurred.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...

View File

@ -4,24 +4,325 @@ from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent from threading import Event as ThreadEvent
from typing import Optional from typing import Optional
from fastapi_events.handlers.local import local_handler from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from fastapi_events.typing import Event as FastAPIEvent from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
from invokeai.app.invocations.baseinvocation import BaseInvocation FastAPIEvent,
from invokeai.app.services.events.events_base import EventServiceBase QueueClearedEvent,
QueueItemStatusChangedEvent,
register_events,
)
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_base import (
OnAfterRunNode,
OnAfterRunSession,
OnBeforeRunNode,
OnBeforeRunSession,
OnNodeError,
OnNonFatalProcessorError,
)
from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations."""
def __init__(
self,
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
):
"""
Args:
on_before_run_session_callbacks: Callbacks to run before the session starts.
on_before_run_node_callbacks: Callbacks to run before each node starts.
on_after_run_node_callbacks: Callbacks to run after each node completes.
on_node_error_callbacks: Callbacks to run when a node errors.
on_after_run_session_callbacks: Callbacks to run after the session completes.
"""
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
def _is_canceled(self) -> bool:
"""Check if the cancel event is set. This is also passed to the invocation context builder and called during
denoising to check if the session has been canceled."""
return self._cancel_event.is_set()
def run(self, queue_item: SessionQueueItem):
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
self._on_before_run_session(queue_item=queue_item)
# Loop over invocations until the session is complete or canceled
while True:
try:
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=e.node,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
break
if invocation is None or self._is_canceled():
break
self.run_node(invocation, queue_item)
# The session is complete if all invocations have been run or there is an error on the session.
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
# use the cancel event to check if the session is canceled.
if (
queue_item.session.is_complete()
or self._is_canceled()
or queue_item.status in ["failed", "canceled", "completed"]
):
break
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
context = build_invocation_context(
data=data,
services=self._services,
is_canceled=self._is_canceled,
)
# Invoke the node
output = invocation.invoke_internal(context=context, services=self._services)
# Save output and history
queue_item.session.complete(invocation.id, output)
self._on_after_run_node(invocation, queue_item, output)
except KeyboardInterrupt:
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
pass
except CanceledException:
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
# correctly in the next iteration of the session runner loop.
#
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
# handle cancellation.
pass
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called before a session is run.
- Start the profiler if profiling is enabled.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id)
for callback in self._on_before_run_session_callbacks:
callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called after a session is run.
- Stop the profiler if profiling is enabled.
- Update the queue item's session object in the database.
- If not already canceled or failed, complete the queue item.
- Log and reset performance statistics.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler is not None:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
try:
# Update the queue item with the completed session. If the queue item has been removed from the queue,
# we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared
# while the session is running.
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# The queue item may have been canceled or failed while the session was running. We should only complete it
# if it is not already canceled or failed.
if queue_item.status not in ["canceled", "failed"]:
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
except SessionQueueItemNotFoundError:
pass
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Called before a node is run.
- Emits an invocation started event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send starting event
self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
for callback in self._on_before_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item)
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
):
"""Called after a node is run.
- Emits an invocation complete event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send complete event on successful runs
self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
for callback in self._on_after_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item, output=output)
def _on_node_error(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
):
"""Called when a node errors. Node errors may occur when running or preparing the node..
- Set the node error on the session object.
- Log the error.
- Fail the queue item.
- Emits an invocation error event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
node_error = f"{error_type}: {error_message}"
queue_item.session.set_node_error(invocation.id, node_error)
self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
)
self._services.logger.error(error_traceback)
# Fail the queue item
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
queue_item = self._services.session_queue.fail_queue_item(
queue_item.item_id, error_type, error_message, error_traceback
)
# Send error event
self._services.events.emit_invocation_error(
queue_item=queue_item,
invocation=invocation,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_node_error_callbacks:
callback(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
class DefaultSessionProcessor(SessionProcessorBase): class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: def __init__(
self,
session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
thread_limit: int = 1,
polling_interval: int = 1,
) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
self._thread_limit = thread_limit
self._polling_interval = polling_interval
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None self._invocation: Optional[BaseInvocation] = None
@ -31,11 +332,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent() self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent() self._cancel_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) register_events(QueueClearedEvent, self._on_queue_cleared)
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
self._thread_limit = thread_limit self._thread_semaphore = BoundedSemaphore(self._thread_limit)
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session. # the profiler will create a new profile for each session.
@ -49,6 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None else None
) )
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
self._thread = Thread( self._thread = Thread(
name="session_processor", name="session_processor",
target=self._process, target=self._process,
@ -67,30 +369,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None: def _poll_now(self) -> None:
self._poll_now_event.set() self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None: async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
event_name = event[1]["event"] if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
self._cancel_event.set()
self._poll_now()
if ( async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
event_name == "session_canceled" self._poll_now()
and self._queue_item
and self._queue_item.item_id == event[1]["data"]["queue_item_id"] async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
): if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
self._cancel_event.set() # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
self._poll_now() # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
elif ( # event, which the session runner checks between invocations. If set, the session runner loop is broken.
event_name == "queue_cleared" #
and self._queue_item # Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
and self._queue_item.queue_id == event[1]["data"]["queue_id"] # node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
): # is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
self._cancel_event.set() if event[1].status == "canceled":
self._poll_now() self._cancel_event.set()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
"completed",
"failed",
"canceled",
]:
self._poll_now() self._poll_now()
def resume(self) -> SessionProcessorStatus: def resume(self) -> SessionProcessorStatus:
@ -116,8 +413,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
resume_event: ThreadEvent, resume_event: ThreadEvent,
cancel_event: ThreadEvent, cancel_event: ThreadEvent,
): ):
# Outermost processor try block; any unhandled exception is a fatal processor error
try: try:
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
self._thread_semaphore.acquire() self._thread_semaphore.acquire()
stop_event.clear() stop_event.clear()
resume_event.set() resume_event.set()
@ -125,8 +422,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
while not stop_event.is_set(): while not stop_event.is_set():
poll_now_event.clear() poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try: try:
# Any unhandled exception in this block is a nonfatal processor error and will be handled.
# If we are paused, wait for resume event # If we are paused, wait for resume event
resume_event.wait() resume_event.wait()
@ -142,159 +439,69 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear() cancel_event.clear()
# If profiling is enabled, start the profiler # Run the graph
if self._profiler is not None: self.session_runner.run(queue_item=self._queue_item)
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first except Exception as e:
self._invocation = self._queue_item.session.next() error_type = e.__class__.__name__
error_message = str(e)
# Loop over invocations until the session is complete or canceled error_traceback = traceback.format_exc()
while self._invocation is not None and not cancel_event.is_set(): self._on_non_fatal_processor_error(
# get the source node id to provide to clients (the prepared node id is not as useful) queue_item=self._queue_item,
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] error_type=error_type,
error_message=error_message,
# Send starting event error_traceback=error_traceback,
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
user_id=None,
project_id=None,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
else:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
) )
# Cancel the queue item # Wait for next polling interval or event to try again
if self._queue_item is not None:
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)
# Reset the invocation to None to prepare for the next session
self._invocation = None
# Immediately poll for next queue item
poll_now_event.wait(self._polling_interval) poll_now_event.wait(self._polling_interval)
continue continue
except Exception: except Exception as e:
# Fatal error in processor, log and pass - we're done here # Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
pass pass
finally: finally:
stop_event.clear() stop_event.clear()
poll_now_event.clear() poll_now_event.clear()
self._queue_item = None self._queue_item = None
self._thread_semaphore.release() self._thread_semaphore.release()
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Called when a non-fatal error occurs in the processor.
- Log the error.
- If a queue item is provided, update the queue item with the completed session & fail it.
- Run any callbacks registered for this event.
"""
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
if queue_item is not None:
# Update the queue item with the completed session & fail it
queue_item = self._invoker.services.session_queue.set_queue_item_session(
queue_item.item_id, queue_item.session
)
queue_item = self._invoker.services.session_queue.fail_queue_item(
item_id=queue_item.item_id,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)

View File

@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO, SessionQueueItemDTO,
SessionQueueStatus, SessionQueueStatus,
) )
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
@ -73,10 +74,22 @@ class SessionQueueBase(ABC):
pass pass
@abstractmethod @abstractmethod
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: def complete_queue_item(self, item_id: int) -> SessionQueueItem:
"""Completes a session queue item"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item""" """Cancels a session queue item"""
pass pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
) -> SessionQueueItem:
"""Fails a session queue item"""
pass
@abstractmethod @abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs""" """Cancels all queue items with matching batch IDs"""
@ -103,3 +116,8 @@ class SessionQueueBase(ABC):
def get_queue_item(self, item_id: int) -> SessionQueueItem: def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID""" """Gets a session queue item by ID"""
pass pass
@abstractmethod
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
"""Sets the session for a session queue item. Use this to update the session state."""
pass

View File

@ -3,7 +3,16 @@ import json
from itertools import chain, product from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
StrictStr,
TypeAdapter,
field_validator,
model_validator,
)
from pydantic_core import to_jsonable_python from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel):
session_id: str = Field( session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
) )
error: Optional[str] = Field(default=None, description="The error message if this queue item errored") error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
error_traceback: Optional[str] = Field(
default=None,
description="The error traceback if this queue item errored",
validation_alias=AliasChoices("error_traceback", "error"),
)
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created") created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated") updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started") started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")

View File

@ -2,10 +2,6 @@ import sqlite3
import threading import threading
from typing import Optional, Union, cast from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
@ -27,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count, calc_session_count,
prepare_values_to_insert, prepare_values_to_insert,
) )
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@ -41,7 +38,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker self.__invoker = invoker
self._set_in_progress_to_canceled() self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID) prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
if prune_result.deleted > 0: if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@ -51,52 +48,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn self.__conn = db.conn
self.__cursor = self.__conn.cursor() self.__cursor = self.__conn.cursor()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
except SessionQueueItemNotFoundError:
return
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
except SessionQueueItemNotFoundError:
return
def _set_in_progress_to_canceled(self) -> None: def _set_in_progress_to_canceled(self) -> None:
""" """
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
@ -271,17 +222,22 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result)) return SessionQueueItem.queue_item_from_dict(dict(result))
def _set_queue_item_status( def _set_queue_item_status(
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None self,
item_id: int,
status: QUEUE_ITEM_STATUS,
error_type: Optional[str] = None,
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem: ) -> SessionQueueItem:
try: try:
self.__lock.acquire() self.__lock.acquire()
self.__cursor.execute( self.__cursor.execute(
"""--sql """--sql
UPDATE session_queue UPDATE session_queue
SET status = ?, error = ? SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
WHERE item_id = ? WHERE item_id = ?
""", """,
(status, error, item_id), (status, error_type, error_message, error_traceback, item_id),
) )
self.__conn.commit() self.__conn.commit()
except Exception: except Exception:
@ -292,11 +248,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id) queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
return queue_item return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult: def is_empty(self, queue_id: str) -> IsEmptyResult:
@ -338,26 +290,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release() self.__lock.release()
return IsFullResult(is_full=is_full) return IsFullResult(is_full=is_full)
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id=item_id)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
DELETE FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return queue_item
def clear(self, queue_id: str) -> ClearResult: def clear(self, queue_id: str) -> ClearResult:
try: try:
self.__lock.acquire() self.__lock.acquire()
@ -424,17 +356,28 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release() self.__lock.release()
return PruneResult(deleted=count) return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id) queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
if queue_item.status not in ["canceled", "failed", "completed"]: return queue_item
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here def complete_queue_item(self, item_id: int) -> SessionQueueItem:
self.__invoker.services.events.emit_session_canceled( queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
queue_item_id=queue_item.item_id, return queue_item
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id, def fail_queue_item(
graph_execution_state_id=queue_item.session_id, self,
) item_id: int,
error_type: str,
error_message: str,
error_traceback: str,
) -> SessionQueueItem:
queue_item = self._set_queue_item_status(
item_id=item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
return queue_item return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
@ -470,18 +413,10 @@ class SqliteSessionQueue(SessionQueueBase):
) )
self.__conn.commit() self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids: if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id) queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item, current_queue_item, batch_status, queue_status
batch_status=batch_status,
queue_status=queue_status,
) )
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
@ -521,18 +456,10 @@ class SqliteSessionQueue(SessionQueueBase):
) )
self.__conn.commit() self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id: if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id) queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item, current_queue_item, batch_status, queue_status
batch_status=batch_status,
queue_status=queue_status,
) )
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
@ -562,6 +489,29 @@ class SqliteSessionQueue(SessionQueueBase):
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result)) return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
WHERE item_id = ?
""",
(session_json, item_id),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def list_queue_items( def list_queue_items(
self, self,
queue_id: str, queue_id: str,
@ -578,7 +528,9 @@ class SqliteSessionQueue(SessionQueueBase):
status, status,
priority, priority,
field_values, field_values,
error, error_type,
error_message,
error_traceback,
created_at, created_at,
updated_at, updated_at,
completed_at, completed_at,

View File

@ -8,6 +8,7 @@ import networkx as nx
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
GetJsonSchemaHandler, GetJsonSchemaHandler,
ValidationError,
field_validator, field_validator,
) )
from pydantic.fields import Field from pydantic.fields import Field
@ -190,6 +191,39 @@ class UnknownGraphValidationError(ValueError):
pass pass
class NodeInputError(ValueError):
"""Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an
input fails validation.
Attributes:
node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to
determine which field caused the failure.
"""
def __init__(self, node: BaseInvocation, e: ValidationError):
self.original_error = e
self.node = node
# When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error
# represents the first input that failed.
self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"])
super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}")
def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str:
"""Helper to pretty-print pydantic error locations as dot-separated strings.
Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages
"""
path = ""
for i, x in enumerate(loc):
if isinstance(x, str):
if i > 0:
path += "."
path += x
else:
path += f"[{x}]"
return path
@invocation_output("iterate_output") @invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput): class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output.""" """Used to connect iteration outputs. Will be expanded to a specific output."""
@ -821,7 +855,10 @@ class GraphExecutionState(BaseModel):
# Get values from edges # Get values from edges
if next_node is not None: if next_node is not None:
self._prepare_inputs(next_node) try:
self._prepare_inputs(next_node)
except ValidationError as e:
raise NodeInputError(next_node, e)
# If next is still none, there's no next node, return None # If next is still none, there's no next node, return None
return next_node return next_node

View File

@ -1,4 +1,3 @@
import threading
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
@ -359,12 +358,11 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str): if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier) model = self._services.model_manager.store.get_model(identifier)
result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data) return self._services.model_manager.load.load_model(model, submodel_type)
else: else:
_submodel_type = submodel_type or identifier.submodel_type _submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key) model = self._services.model_manager.store.get_model(identifier.key)
result = self._services.model_manager.load.load_model(model, _submodel_type, self._data) return self._services.model_manager.load.load_model(model, _submodel_type)
return result
def load_by_attrs( def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -388,8 +386,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1: if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) return self._services.model_manager.load.load_model(configs[0], submodel_type)
return result
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Get a model's config. """Get a model's config.
@ -516,10 +513,10 @@ class ConfigInterface(InvocationContextInterface):
class UtilInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface):
def __init__( def __init__(
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
) -> None: ) -> None:
super().__init__(services, data) super().__init__(services, data)
self._cancel_event = cancel_event self._is_canceled = is_canceled
def is_canceled(self) -> bool: def is_canceled(self) -> bool:
"""Checks if the current session has been canceled. """Checks if the current session has been canceled.
@ -527,7 +524,7 @@ class UtilInterface(InvocationContextInterface):
Returns: Returns:
True if the current session has been canceled, False if not. True if the current session has been canceled, False if not.
""" """
return self._cancel_event.is_set() return self._is_canceled()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
""" """
@ -602,7 +599,7 @@ class InvocationContext:
def build_invocation_context( def build_invocation_context(
services: InvocationServices, services: InvocationServices,
data: InvocationContextData, data: InvocationContextData,
cancel_event: threading.Event, is_canceled: Callable[[], bool],
) -> InvocationContext: ) -> InvocationContext:
"""Builds the invocation context for a specific invocation execution. """Builds the invocation context for a specific invocation execution.
@ -619,7 +616,7 @@ def build_invocation_context(
tensors = TensorsInterface(services=services, data=data) tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data) models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data) config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, cancel_event=cancel_event) util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data) conditioning = ConditioningInterface(services=services, data=data)
boards = BoardsInterface(services=services, data=data) boards = BoardsInterface(services=services, data=data)

View File

@ -42,7 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_7()) migrator.register_migration(build_migration_7())
migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9()) migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10(app_config=config, logger=logger)) migrator.register_migration(build_migration_10())
migrator.run_migrations() migrator.run_migrations()
return db return db

View File

@ -1,75 +1,35 @@
import shutil
import sqlite3 import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
LEGACY_CORE_MODELS = [
# OpenPose
"any/annotators/dwpose/yolox_l.onnx",
"any/annotators/dwpose/dw-ll_ucoco_384.onnx",
# DepthAnything
"any/annotators/depth_anything/depth_anything_vitl14.pth",
"any/annotators/depth_anything/depth_anything_vitb14.pth",
"any/annotators/depth_anything/depth_anything_vits14.pth",
# Lama inpaint
"core/misc/lama/lama.pt",
# RealESRGAN upscale
"core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
]
class Migration10Callback: class Migration10Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None: def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_convert_cache() self._update_error_cols(cursor)
self._remove_downloaded_models()
self._remove_unused_core_models()
def _remove_convert_cache(self) -> None: def _update_error_cols(self, cursor: sqlite3.Cursor) -> None:
"""Rename models/.cache to models/.convert_cache.""" """
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.") - Adds `error_type` and `error_message` columns to the session queue table.
legacy_convert_path = self._app_config.root_path / "models" / ".cache" - Renames the `error` column to `error_traceback`.
shutil.rmtree(legacy_convert_path, ignore_errors=True) """
def _remove_downloaded_models(self) -> None: cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;")
"""Remove models from their old locations; they will re-download when needed.""" cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;")
self._logger.info( cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;")
"Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
)
for model_path in LEGACY_CORE_MODELS:
legacy_dest_path = self._app_config.models_path / model_path
legacy_dest_path.unlink(missing_ok=True)
def _remove_unused_core_models(self) -> None:
"""Remove unused core models and their directories."""
self._logger.info("Removing defunct core models.")
for dir in ["face_restoration", "misc", "upscaling"]:
path_to_remove = self._app_config.models_path / "core" / dir
shutil.rmtree(path_to_remove, ignore_errors=True)
shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
def build_migration_10(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: def build_migration_10() -> Migration:
""" """
Build the migration from database version 9 to 10. Build the migration from database version 9 to 10.
This migration does the following: This migration does the following:
- Moves "core" models previously downloaded with download_with_progress_bar() into new - Adds `error_type` and `error_message` columns to the session queue table.
"models/.download_cache" directory. - Renames the `error` column to `error_traceback`.
- Renames "models/.cache" to "models/.convert_cache".
""" """
migration_10 = Migration( migration_10 = Migration(
from_version=9, from_version=9,
to_version=10, to_version=10,
callback=Migration10Callback(app_config=app_config, logger=logger), callback=Migration10Callback(),
) )
return migration_10 return migration_10

View File

@ -0,0 +1,77 @@
import shutil
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
LEGACY_CORE_MODELS = [
# OpenPose
"any/annotators/dwpose/yolox_l.onnx",
"any/annotators/dwpose/dw-ll_ucoco_384.onnx",
# DepthAnything
"any/annotators/depth_anything/depth_anything_vitl14.pth",
"any/annotators/depth_anything/depth_anything_vitb14.pth",
"any/annotators/depth_anything/depth_anything_vits14.pth",
# Lama inpaint
"core/misc/lama/lama.pt",
# RealESRGAN upscale
"core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
]
class Migration11Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_convert_cache()
self._remove_downloaded_models()
self._remove_unused_core_models()
def _remove_convert_cache(self) -> None:
"""Rename models/.cache to models/.convert_cache."""
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
legacy_convert_path = self._app_config.root_path / "models" / ".cache"
shutil.rmtree(legacy_convert_path, ignore_errors=True)
def _remove_downloaded_models(self) -> None:
"""Remove models from their old locations; they will re-download when needed."""
self._logger.info(
"Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
)
for model_path in LEGACY_CORE_MODELS:
legacy_dest_path = self._app_config.models_path / model_path
legacy_dest_path.unlink(missing_ok=True)
def _remove_unused_core_models(self) -> None:
"""Remove unused core models and their directories."""
self._logger.info("Removing defunct core models.")
for dir in ["face_restoration", "misc", "upscaling"]:
path_to_remove = self._app_config.models_path / "core" / dir
shutil.rmtree(path_to_remove, ignore_errors=True)
shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 9 to 10.
This migration does the following:
- Moves "core" models previously downloaded with download_with_progress_bar() into new
"models/.download_cache" directory.
- Renames "models/.cache" to "models/.convert_cache".
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
migration_11 = Migration(
from_version=10,
to_version=11,
callback=Migration11Callback(app_config=app_config, logger=logger),
)
return migration_11

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Callable, Optional
import torch import torch
from PIL import Image from PIL import Image
@ -13,8 +13,36 @@ if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
SDXL_LATENT_RGB_FACTORS = [
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
]
SDXL_SMOOTH_MATRIX = [
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
]
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): # origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
SD1_5_LATENT_RGB_FACTORS = [
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None: if smooth_matrix is not None:
@ -47,64 +75,12 @@ def stable_diffusion_step_callback(
else: else:
sample = intermediate_state.latents sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]: if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
# fast latents preview matrix for sdxl sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
# generated by @StAlKeR7779 sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix) image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
else: else:
# origingally adapted from code by @erucipe and @keturn here: v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors) image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size (width, height) = image.size
@ -113,15 +89,9 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG") dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress( events.emit_invocation_denoise_progress(
queue_id=context_data.queue_item.queue_id, context_data.queue_item,
queue_item_id=context_data.queue_item.item_id, context_data.invocation,
queue_batch_id=context_data.queue_item.batch_id, intermediate_state,
graph_execution_state_id=context_data.queue_item.session_id, ProgressImage(dataURL=dataURL, width=width, height=height),
node_id=context_data.invocation.id,
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
) )

View File

@ -42,10 +42,26 @@ T = TypeVar("T")
@dataclass @dataclass
class CacheRecord(Generic[T]): class CacheRecord(Generic[T]):
"""Elements of the cache.""" """
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
"""
key: str key: str
model: T model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int size: int
loaded: bool = False loaded: bool = False
_locks: int = 0 _locks: int = 0

View File

@ -20,7 +20,6 @@ context. Use like this:
import gc import gc
import math import math
import sys
import time import time
from contextlib import suppress from contextlib import suppress
from logging import Logger from logging import Logger
@ -163,7 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
return return
size = calc_model_size_by_data(model) size = calc_model_size_by_data(model)
self.make_room(size) self.make_room(size)
cache_record = CacheRecord(key, model, size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)
@ -253,21 +254,40 @@ class ModelCache(ModelCacheBase[AnyModel]):
May raise a torch.cuda.OutOfMemoryError May raise a torch.cuda.OutOfMemoryError
""" """
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
model = cache_entry.model source_device = cache_entry.device
source_device = model.device if hasattr(model, "device") else self.storage_device # Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type: if torch.device(source_device).type == torch.device(target_device).type:
return return
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time() start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot() snapshot_before = self._capture_memory_snapshot()
try: try:
if hasattr(model, "to"): if cache_entry.state_dict is not None:
model.to(target_device) assert hasattr(cache_entry.model, "load_state_dict")
elif isinstance(model, dict): if target_device == self.storage_device:
for _, v in model.items(): cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
if hasattr(v, "to"): else:
v.to(target_device) new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry) self._delete_cache_entry(cache_entry)
raise e raise e
@ -347,43 +367,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos] model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key] cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug( self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
f" refs: {refs}"
) )
# Expected refs: if not cache_entry.locked:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug( self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
) )

View File

@ -2,6 +2,7 @@
"accessibility": { "accessibility": {
"about": "About", "about": "About",
"createIssue": "Create Issue", "createIssue": "Create Issue",
"submitSupportTicket": "Submit Support Ticket",
"invokeProgressBar": "Invoke progress bar", "invokeProgressBar": "Invoke progress bar",
"menu": "Menu", "menu": "Menu",
"mode": "Mode", "mode": "Mode",
@ -146,7 +147,9 @@
"viewing": "Viewing", "viewing": "Viewing",
"viewingDesc": "Review images in a large gallery view", "viewingDesc": "Review images in a large gallery view",
"editing": "Editing", "editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas" "editingDesc": "Edit on the Control Layers canvas",
"enabled": "Enabled",
"disabled": "Disabled"
}, },
"controlnet": { "controlnet": {
"controlAdapter_one": "Control Adapter", "controlAdapter_one": "Control Adapter",
@ -775,10 +778,14 @@
"cannotConnectToSelf": "Cannot connect to self", "cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections", "cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types", "cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"missingNode": "Missing invocation node",
"missingInvocationTemplate": "Missing invocation template",
"missingFieldTemplate": "Missing field template",
"nodePack": "Node pack", "nodePack": "Node pack",
"collection": "Collection", "collection": "Collection",
"collectionFieldType": "{{name}} Collection", "singleFieldType": "{{name}} (Single)",
"collectionOrScalarFieldType": "{{name}} Collection|Scalar", "collectionFieldType": "{{name}} (Collection)",
"collectionOrScalarFieldType": "{{name}} (Single or Collection)",
"colorCodeEdges": "Color-Code Edges", "colorCodeEdges": "Color-Code Edges",
"colorCodeEdgesHelp": "Color-code edges according to their connected fields", "colorCodeEdgesHelp": "Color-code edges according to their connected fields",
"connectionWouldCreateCycle": "Connection would create a cycle", "connectionWouldCreateCycle": "Connection would create a cycle",
@ -893,7 +900,10 @@
"zoomInNodes": "Zoom In", "zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out", "zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.", "betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time." "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default"
}, },
"parameters": { "parameters": {
"aspect": "Aspect", "aspect": "Aspect",
@ -948,7 +958,7 @@
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model", "controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
"controlAdapterNoImageSelected": "no Control Adapter image selected", "controlAdapterNoImageSelected": "no Control Adapter image selected",
"controlAdapterImageNotProcessed": "Control Adapter image not processed", "controlAdapterImageNotProcessed": "Control Adapter image not processed",
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64", "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}",
"ipAdapterNoModelSelected": "no IP adapter selected", "ipAdapterNoModelSelected": "no IP adapter selected",
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model", "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
"ipAdapterNoImageSelected": "no IP Adapter image selected", "ipAdapterNoImageSelected": "no IP Adapter image selected",
@ -1066,8 +1076,9 @@
}, },
"toast": { "toast": {
"addedToBoard": "Added to board", "addedToBoard": "Added to board",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{count}} incompatible submodel", "baseModelChanged": "Base Model Changed",
"baseModelChangedCleared_other": "Base model changed, cleared or disabled {{count}} incompatible submodels", "baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels",
"canceled": "Processing Canceled", "canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard", "canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded", "canvasDownloaded": "Canvas Downloaded",
@ -1088,10 +1099,17 @@
"metadataLoadFailed": "Failed to load metadata", "metadataLoadFailed": "Failed to load metadata",
"modelAddedSimple": "Model Added to Queue", "modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled", "modelImportCanceled": "Model Import Canceled",
"outOfMemoryError": "Out of Memory Error",
"outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.",
"parameters": "Parameters", "parameters": "Parameters",
"parameterNotSet": "{{parameter}} not set", "parameterSet": "Parameter Recalled",
"parameterSet": "{{parameter}} set", "parameterSetDesc": "Recalled {{parameter}}",
"parametersNotSet": "Parameters Not Set", "parameterNotSet": "Parameter Not Recalled",
"parameterNotSetDesc": "Unable to recall {{parameter}}",
"parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}",
"parametersSet": "Parameters Recalled",
"parametersNotSet": "Parameters Not Recalled",
"errorCopied": "Error Copied",
"problemCopyingCanvas": "Problem Copying Canvas", "problemCopyingCanvas": "Problem Copying Canvas",
"problemCopyingCanvasDesc": "Unable to export base layer", "problemCopyingCanvasDesc": "Unable to export base layer",
"problemCopyingImage": "Unable to Copy Image", "problemCopyingImage": "Unable to Copy Image",
@ -1111,11 +1129,13 @@
"sentToImageToImage": "Sent To Image To Image", "sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas", "sentToUnifiedCanvas": "Sent to Unified Canvas",
"serverError": "Server Error", "serverError": "Server Error",
"sessionRef": "Session: {{sessionId}}",
"setAsCanvasInitialImage": "Set as canvas initial image", "setAsCanvasInitialImage": "Set as canvas initial image",
"setCanvasInitialImage": "Set canvas initial image", "setCanvasInitialImage": "Set canvas initial image",
"setControlImage": "Set as control image", "setControlImage": "Set as control image",
"setInitialImage": "Set as initial image", "setInitialImage": "Set as initial image",
"setNodeField": "Set as node field", "setNodeField": "Set as node field",
"somethingWentWrong": "Something Went Wrong",
"uploadFailed": "Upload failed", "uploadFailed": "Upload failed",
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"uploadInitialImage": "Upload Initial Image", "uploadInitialImage": "Upload Initial Image",
@ -1555,7 +1575,6 @@
"controlLayers": "Control Layers", "controlLayers": "Control Layers",
"globalMaskOpacity": "Global Mask Opacity", "globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative", "autoNegative": "Auto Negative",
"toggleVisibility": "Toggle Layer Visibility",
"deletePrompt": "Delete Prompt", "deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region", "resetRegion": "Reset Region",
"debugLayers": "Debug Layers", "debugLayers": "Debug Layers",

View File

@ -382,7 +382,7 @@
"canvasMerged": "Lienzo consolidado", "canvasMerged": "Lienzo consolidado",
"sentToImageToImage": "Enviar hacia Imagen a Imagen", "sentToImageToImage": "Enviar hacia Imagen a Imagen",
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado", "sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
"parametersNotSet": "Parámetros no establecidos", "parametersNotSet": "Parámetros no recuperados",
"metadataLoadFailed": "Error al cargar metadatos", "metadataLoadFailed": "Error al cargar metadatos",
"serverError": "Error en el servidor", "serverError": "Error en el servidor",
"canceled": "Procesando la cancelación", "canceled": "Procesando la cancelación",
@ -390,7 +390,8 @@
"uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG", "uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG",
"parameterSet": "Conjunto de parámetros", "parameterSet": "Conjunto de parámetros",
"parameterNotSet": "Parámetro no configurado", "parameterNotSet": "Parámetro no configurado",
"problemCopyingImage": "No se puede copiar la imagen" "problemCopyingImage": "No se puede copiar la imagen",
"errorCopied": "Error al copiar"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {

View File

@ -524,7 +524,20 @@
"missingNodeTemplate": "Modello di nodo mancante", "missingNodeTemplate": "Modello di nodo mancante",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante", "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante",
"missingFieldTemplate": "Modello di campo mancante", "missingFieldTemplate": "Modello di campo mancante",
"imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata" "imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata",
"layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
"controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
"ipAdapterNoModelSelected": "Nessun adattatore IP selezionato",
"ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile",
"ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata",
"rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP",
"rgNoRegion": "Nessuna regione selezionata"
}
}, },
"useCpuNoise": "Usa la CPU per generare rumore", "useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni", "iterations": "Iterazioni",
@ -824,8 +837,8 @@
"unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi", "unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi",
"addLinearView": "Aggiungi alla vista Lineare", "addLinearView": "Aggiungi alla vista Lineare",
"unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro", "unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro",
"collectionFieldType": "{{name}} Raccolta", "collectionFieldType": "{{name}} (Raccolta)",
"collectionOrScalarFieldType": "{{name}} Raccolta|Scalare", "collectionOrScalarFieldType": "{{name}} (Singola o Raccolta)",
"nodeVersion": "Versione Nodo", "nodeVersion": "Versione Nodo",
"inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})", "inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})",
"unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"", "unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"",
@ -863,7 +876,13 @@
"edit": "Modifica", "edit": "Modifica",
"graph": "Grafico", "graph": "Grafico",
"showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati", "showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati",
"showEdgeLabels": "Mostra le etichette del collegamento" "showEdgeLabels": "Mostra le etichette del collegamento",
"cannotMixAndMatchCollectionItemTypes": "Impossibile combinare e abbinare i tipi di elementi della raccolta",
"noGraph": "Nessun grafico",
"missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)"
}, },
"boards": { "boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca", "autoAddBoard": "Aggiungi automaticamente bacheca",
@ -1034,7 +1053,16 @@
"graphFailedToQueue": "Impossibile mettere in coda il grafico", "graphFailedToQueue": "Impossibile mettere in coda il grafico",
"batchFieldValues": "Valori Campi Lotto", "batchFieldValues": "Valori Campi Lotto",
"time": "Tempo", "time": "Tempo",
"openQueue": "Apri coda" "openQueue": "Apri coda",
"iterations_one": "Iterazione",
"iterations_many": "Iterazioni",
"iterations_other": "Iterazioni",
"prompts_one": "Prompt",
"prompts_many": "Prompt",
"prompts_other": "Prompt",
"generations_one": "Generazione",
"generations_many": "Generazioni",
"generations_other": "Generazioni"
}, },
"models": { "models": {
"noMatchingModels": "Nessun modello corrispondente", "noMatchingModels": "Nessun modello corrispondente",
@ -1563,7 +1591,6 @@
"brushSize": "Dimensioni del pennello", "brushSize": "Dimensioni del pennello",
"globalMaskOpacity": "Opacità globale della maschera", "globalMaskOpacity": "Opacità globale della maschera",
"autoNegative": "Auto Negativo", "autoNegative": "Auto Negativo",
"toggleVisibility": "Attiva/disattiva la visibilità dei livelli",
"deletePrompt": "Cancella il prompt", "deletePrompt": "Cancella il prompt",
"debugLayers": "Debug dei Livelli", "debugLayers": "Debug dei Livelli",
"rectangle": "Rettangolo", "rectangle": "Rettangolo",

View File

@ -6,7 +6,7 @@
"settingsLabel": "Instellingen", "settingsLabel": "Instellingen",
"img2img": "Afbeelding naar afbeelding", "img2img": "Afbeelding naar afbeelding",
"unifiedCanvas": "Centraal canvas", "unifiedCanvas": "Centraal canvas",
"nodes": "Werkstroom-editor", "nodes": "Werkstromen",
"upload": "Upload", "upload": "Upload",
"load": "Laad", "load": "Laad",
"statusDisconnected": "Niet verbonden", "statusDisconnected": "Niet verbonden",
@ -34,7 +34,60 @@
"controlNet": "ControlNet", "controlNet": "ControlNet",
"imageFailedToLoad": "Kan afbeelding niet laden", "imageFailedToLoad": "Kan afbeelding niet laden",
"learnMore": "Meer informatie", "learnMore": "Meer informatie",
"advanced": "Uitgebreid" "advanced": "Uitgebreid",
"file": "Bestand",
"installed": "Geïnstalleerd",
"notInstalled": "Niet $t(common.installed)",
"simple": "Eenvoudig",
"somethingWentWrong": "Er ging iets mis",
"add": "Voeg toe",
"checkpoint": "Checkpoint",
"details": "Details",
"outputs": "Uitvoeren",
"save": "Bewaar",
"nextPage": "Volgende pagina",
"blue": "Blauw",
"alpha": "Alfa",
"red": "Rood",
"editor": "Editor",
"folder": "Map",
"format": "structuur",
"goTo": "Ga naar",
"template": "Sjabloon",
"input": "Invoer",
"loglevel": "Logboekniveau",
"safetensors": "Safetensors",
"saveAs": "Bewaar als",
"created": "Gemaakt",
"green": "Groen",
"tab": "Tab",
"positivePrompt": "Positieve prompt",
"negativePrompt": "Negatieve prompt",
"selected": "Geselecteerd",
"orderBy": "Sorteer op",
"prevPage": "Vorige pagina",
"beta": "Bèta",
"copyError": "$t(gallery.copy) Fout",
"toResolve": "Op te lossen",
"aboutDesc": "Gebruik je Invoke voor het werk? Kijk dan naar:",
"aboutHeading": "Creatieve macht voor jou",
"copy": "Kopieer",
"data": "Gegevens",
"or": "of",
"updated": "Bijgewerkt",
"outpaint": "outpainten",
"viewing": "Bekijken",
"viewingDesc": "Beoordeel afbeelding in een grote galerijweergave",
"editing": "Bewerken",
"editingDesc": "Bewerk op het canvas Stuurlagen",
"ai": "ai",
"inpaint": "inpainten",
"unknown": "Onbekend",
"delete": "Verwijder",
"direction": "Richting",
"error": "Fout",
"localSystem": "Lokaal systeem",
"unknownError": "Onbekende fout"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Afbeeldingsgrootte", "galleryImageSize": "Afbeeldingsgrootte",
@ -310,10 +363,41 @@
"modelSyncFailed": "Synchronisatie modellen mislukt", "modelSyncFailed": "Synchronisatie modellen mislukt",
"modelDeleteFailed": "Model kon niet verwijderd worden", "modelDeleteFailed": "Model kon niet verwijderd worden",
"convertingModelBegin": "Model aan het converteren. Even geduld.", "convertingModelBegin": "Model aan het converteren. Even geduld.",
"predictionType": "Soort voorspelling (voor Stable Diffusion 2.x-modellen en incidentele Stable Diffusion 1.x-modellen)", "predictionType": "Soort voorspelling",
"advanced": "Uitgebreid", "advanced": "Uitgebreid",
"modelType": "Soort model", "modelType": "Soort model",
"vaePrecision": "Nauwkeurigheid VAE" "vaePrecision": "Nauwkeurigheid VAE",
"loraTriggerPhrases": "LoRA-triggerzinnen",
"urlOrLocalPathHelper": "URL's zouden moeten wijzen naar een los bestand. Lokale paden kunnen wijzen naar een los bestand of map voor een individueel Diffusers-model.",
"modelName": "Modelnaam",
"path": "Pad",
"triggerPhrases": "Triggerzinnen",
"typePhraseHere": "Typ zin hier in",
"useDefaultSettings": "Gebruik standaardinstellingen",
"modelImageDeleteFailed": "Fout bij verwijderen modelafbeelding",
"modelImageUpdated": "Modelafbeelding bijgewerkt",
"modelImageUpdateFailed": "Fout bij bijwerken modelafbeelding",
"noMatchingModels": "Geen overeenkomende modellen",
"scanPlaceholder": "Pad naar een lokale map",
"noModelsInstalled": "Geen modellen geïnstalleerd",
"noModelsInstalledDesc1": "Installeer modellen met de",
"noModelSelected": "Geen model geselecteerd",
"starterModels": "Beginnermodellen",
"textualInversions": "Tekstuele omkeringen",
"upcastAttention": "Upcast-aandacht",
"uploadImage": "Upload afbeelding",
"mainModelTriggerPhrases": "Triggerzinnen hoofdmodel",
"urlOrLocalPath": "URL of lokaal pad",
"scanFolderHelper": "De map zal recursief worden ingelezen voor modellen. Dit kan enige tijd in beslag nemen voor erg grote mappen.",
"simpleModelPlaceholder": "URL of pad naar een lokaal pad of Diffusers-map",
"modelSettings": "Modelinstellingen",
"pathToConfig": "Pad naar configuratie",
"prune": "Snoei",
"pruneTooltip": "Snoei voltooide importeringen uit wachtrij",
"repoVariant": "Repovariant",
"scanFolder": "Lees map in",
"scanResults": "Resultaten inlezen",
"source": "Bron"
}, },
"parameters": { "parameters": {
"images": "Afbeeldingen", "images": "Afbeeldingen",
@ -353,13 +437,13 @@
"copyImage": "Kopieer afbeelding", "copyImage": "Kopieer afbeelding",
"denoisingStrength": "Sterkte ontruisen", "denoisingStrength": "Sterkte ontruisen",
"scheduler": "Planner", "scheduler": "Planner",
"seamlessXAxis": "X-as", "seamlessXAxis": "Naadloze tegels in x-as",
"seamlessYAxis": "Y-as", "seamlessYAxis": "Naadloze tegels in y-as",
"clipSkip": "Overslaan CLIP", "clipSkip": "Overslaan CLIP",
"negativePromptPlaceholder": "Negatieve prompt", "negativePromptPlaceholder": "Negatieve prompt",
"controlNetControlMode": "Aansturingsmodus", "controlNetControlMode": "Aansturingsmodus",
"positivePromptPlaceholder": "Positieve prompt", "positivePromptPlaceholder": "Positieve prompt",
"maskBlur": "Vervaag", "maskBlur": "Vervaging van masker",
"invoke": { "invoke": {
"noNodesInGraph": "Geen knooppunten in graaf", "noNodesInGraph": "Geen knooppunten in graaf",
"noModelSelected": "Geen model ingesteld", "noModelSelected": "Geen model ingesteld",
@ -369,11 +453,25 @@
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt", "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt",
"noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding", "noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding",
"noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.", "noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.",
"incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is ongeldig in combinatie met het hoofdmodel.", "incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is niet compatibel met het hoofdmodel.",
"systemDisconnected": "Systeem is niet verbonden", "systemDisconnected": "Systeem is niet verbonden",
"missingNodeTemplate": "Knooppuntsjabloon ontbreekt", "missingNodeTemplate": "Knooppuntsjabloon ontbreekt",
"missingFieldTemplate": "Veldsjabloon ontbreekt", "missingFieldTemplate": "Veldsjabloon ontbreekt",
"addingImagesTo": "Bezig met toevoegen van afbeeldingen aan" "addingImagesTo": "Bezig met toevoegen van afbeeldingen aan",
"layer": {
"initialImageNoImageSelected": "geen initiële afbeelding geselecteerd",
"controlAdapterNoModelSelected": "geen controle-adaptermodel geselecteerd",
"controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter",
"controlAdapterNoImageSelected": "geen afbeelding voor controle-adapter geselecteerd",
"controlAdapterImageNotProcessed": "Afbeelding voor controle-adapter niet verwerkt",
"ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter",
"ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd",
"rgNoRegion": "geen gebied geselecteerd",
"rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters",
"t2iAdapterIncompatibleDimensions": "T2I-adapter vereist een afbeelding met afmetingen met een veelvoud van 64",
"ipAdapterNoModelSelected": "geen IP-adapter geselecteerd"
},
"imageNotProcessedForControlAdapter": "De afbeelding van controle-adapter #{{number}} is niet verwerkt"
}, },
"isAllowedToUpscale": { "isAllowedToUpscale": {
"useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model", "useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model",
@ -383,7 +481,26 @@
"useCpuNoise": "Gebruik CPU-ruis", "useCpuNoise": "Gebruik CPU-ruis",
"imageActions": "Afbeeldingshandeling", "imageActions": "Afbeeldingshandeling",
"iterations": "Iteraties", "iterations": "Iteraties",
"coherenceMode": "Modus" "coherenceMode": "Modus",
"infillColorValue": "Vulkleur",
"remixImage": "Meng afbeelding opnieuw",
"setToOptimalSize": "Optimaliseer grootte voor het model",
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (is mogelijk te klein)",
"aspect": "Beeldverhouding",
"infillMosaicTileWidth": "Breedte tegel",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (is mogelijk te groot)",
"lockAspectRatio": "Zet beeldverhouding vast",
"infillMosaicTileHeight": "Hoogte tegel",
"globalNegativePromptPlaceholder": "Globale negatieve prompt",
"globalPositivePromptPlaceholder": "Globale positieve prompt",
"useSize": "Gebruik grootte",
"swapDimensions": "Wissel afmetingen om",
"globalSettings": "Globale instellingen",
"coherenceEdgeSize": "Randgrootte",
"coherenceMinDenoise": "Min. ontruising",
"infillMosaicMinColor": "Min. kleur",
"infillMosaicMaxColor": "Max. kleur",
"cfgRescaleMultiplier": "Vermenigvuldiger voor CFG-herschaling"
}, },
"settings": { "settings": {
"models": "Modellen", "models": "Modellen",
@ -410,7 +527,12 @@
"intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist", "intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist",
"intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist", "intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist",
"clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.", "clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.",
"intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen" "intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen",
"clearIntermediatesDisabled": "Wachtrij moet leeg zijn om tussentijdse afbeeldingen te kunnen leegmaken",
"enableInformationalPopovers": "Schakel informatieve hulpballonnen in",
"enableInvisibleWatermark": "Schakel onzichtbaar watermerk in",
"enableNSFWChecker": "Schakel NSFW-controle in",
"reloadingIn": "Opnieuw laden na"
}, },
"toast": { "toast": {
"uploadFailed": "Upload mislukt", "uploadFailed": "Upload mislukt",
@ -425,8 +547,8 @@
"connected": "Verbonden met server", "connected": "Verbonden met server",
"canceled": "Verwerking geannuleerd", "canceled": "Verwerking geannuleerd",
"uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn", "uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn",
"parameterNotSet": "Parameter niet ingesteld", "parameterNotSet": "{{parameter}} niet ingesteld",
"parameterSet": "Instellen parameters", "parameterSet": "{{parameter}} ingesteld",
"problemCopyingImage": "Kan Afbeelding Niet Kopiëren", "problemCopyingImage": "Kan Afbeelding Niet Kopiëren",
"baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld", "baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld",
"baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld", "baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld",
@ -443,11 +565,11 @@
"maskSavedAssets": "Masker bewaard in Assets", "maskSavedAssets": "Masker bewaard in Assets",
"problemDownloadingCanvas": "Fout bij downloaden van canvas", "problemDownloadingCanvas": "Fout bij downloaden van canvas",
"problemMergingCanvas": "Fout bij samenvoegen canvas", "problemMergingCanvas": "Fout bij samenvoegen canvas",
"setCanvasInitialImage": "Ingesteld als initiële canvasafbeelding", "setCanvasInitialImage": "Initiële canvasafbeelding ingesteld",
"imageUploaded": "Afbeelding geüpload", "imageUploaded": "Afbeelding geüpload",
"addedToBoard": "Toegevoegd aan bord", "addedToBoard": "Toegevoegd aan bord",
"workflowLoaded": "Werkstroom geladen", "workflowLoaded": "Werkstroom geladen",
"modelAddedSimple": "Model toegevoegd", "modelAddedSimple": "Model toegevoegd aan wachtrij",
"problemImportingMaskDesc": "Kan masker niet exporteren", "problemImportingMaskDesc": "Kan masker niet exporteren",
"problemCopyingCanvas": "Fout bij kopiëren canvas", "problemCopyingCanvas": "Fout bij kopiëren canvas",
"problemSavingCanvas": "Fout bij bewaren canvas", "problemSavingCanvas": "Fout bij bewaren canvas",
@ -459,7 +581,18 @@
"maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets", "maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets",
"canvasSavedGallery": "Canvas bewaard in galerij", "canvasSavedGallery": "Canvas bewaard in galerij",
"imageUploadFailed": "Fout bij uploaden afbeelding", "imageUploadFailed": "Fout bij uploaden afbeelding",
"problemImportingMask": "Fout bij importeren masker" "problemImportingMask": "Fout bij importeren masker",
"workflowDeleted": "Werkstroom verwijderd",
"invalidUpload": "Ongeldige upload",
"uploadInitialImage": "Initiële afbeelding uploaden",
"setAsCanvasInitialImage": "Ingesteld als initiële afbeelding voor canvas",
"problemRetrievingWorkflow": "Fout bij ophalen van werkstroom",
"parameters": "Parameters",
"modelImportCanceled": "Importeren model geannuleerd",
"problemDeletingWorkflow": "Fout bij verwijderen van werkstroom",
"prunedQueue": "Wachtrij gesnoeid",
"problemDownloadingImage": "Fout bij downloaden afbeelding",
"resetInitialImage": "Initiële afbeelding hersteld"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -533,7 +666,11 @@
"showOptionsPanel": "Toon zijscherm", "showOptionsPanel": "Toon zijscherm",
"menu": "Menu", "menu": "Menu",
"showGalleryPanel": "Toon deelscherm Galerij", "showGalleryPanel": "Toon deelscherm Galerij",
"loadMore": "Laad meer" "loadMore": "Laad meer",
"about": "Over",
"mode": "Modus",
"resetUI": "$t(accessibility.reset) UI",
"createIssue": "Maak probleem aan"
}, },
"nodes": { "nodes": {
"zoomOutNodes": "Uitzoomen", "zoomOutNodes": "Uitzoomen",
@ -547,7 +684,7 @@
"loadWorkflow": "Laad werkstroom", "loadWorkflow": "Laad werkstroom",
"downloadWorkflow": "Download JSON van werkstroom", "downloadWorkflow": "Download JSON van werkstroom",
"scheduler": "Planner", "scheduler": "Planner",
"missingTemplate": "Ontbrekende sjabloon", "missingTemplate": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een ontbrekend sjabloon (niet geïnstalleerd?)",
"workflowDescription": "Korte beschrijving", "workflowDescription": "Korte beschrijving",
"versionUnknown": " Versie onbekend", "versionUnknown": " Versie onbekend",
"noNodeSelected": "Geen knooppunt gekozen", "noNodeSelected": "Geen knooppunt gekozen",
@ -563,7 +700,7 @@
"integer": "Geheel getal", "integer": "Geheel getal",
"nodeTemplate": "Sjabloon knooppunt", "nodeTemplate": "Sjabloon knooppunt",
"nodeOpacity": "Dekking knooppunt", "nodeOpacity": "Dekking knooppunt",
"unableToLoadWorkflow": "Kan werkstroom niet valideren", "unableToLoadWorkflow": "Fout bij laden werkstroom",
"snapToGrid": "Lijn uit op raster", "snapToGrid": "Lijn uit op raster",
"noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave", "noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave",
"nodeSearch": "Zoek naar knooppunten", "nodeSearch": "Zoek naar knooppunten",
@ -614,11 +751,56 @@
"unknownField": "Onbekend veld", "unknownField": "Onbekend veld",
"colorCodeEdges": "Kleurgecodeerde randen", "colorCodeEdges": "Kleurgecodeerde randen",
"unknownNode": "Onbekend knooppunt", "unknownNode": "Onbekend knooppunt",
"mismatchedVersion": "Heeft niet-overeenkomende versie", "mismatchedVersion": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een niet-overeenkomende versie (probeer het bij te werken?)",
"addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)", "addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)",
"loadingNodes": "Bezig met laden van knooppunten...", "loadingNodes": "Bezig met laden van knooppunten...",
"snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing", "snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing",
"workflowSettings": "Instellingen werkstroomeditor" "workflowSettings": "Instellingen werkstroomeditor",
"addLinearView": "Voeg toe aan lineaire weergave",
"nodePack": "Knooppuntpakket",
"unknownInput": "Onbekende invoer: {{name}}",
"sourceNodeFieldDoesNotExist": "Ongeldige rand: bron-/uitvoerveld {{node}}.{{field}} bestaat niet",
"collectionFieldType": "Verzameling {{name}}",
"deletedInvalidEdge": "Ongeldige hoek {{source}} -> {{target}} verwijderd",
"graph": "Grafiek",
"targetNodeDoesNotExist": "Ongeldige rand: doel-/invoerknooppunt {{node}} bestaat niet",
"resetToDefaultValue": "Herstel naar standaardwaarden",
"editMode": "Bewerk in Werkstroom-editor",
"showEdgeLabels": "Toon randlabels",
"showEdgeLabelsHelp": "Toon labels aan randen, waarmee de verbonden knooppunten mee worden aangegeven",
"clearWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"unableToParseFieldType": "fout bij bepalen soort veld",
"sourceNodeDoesNotExist": "Ongeldige rand: bron-/uitvoerknooppunt {{node}} bestaat niet",
"unsupportedArrayItemType": "niet-ondersteunde soort van het array-onderdeel \"{{type}}\"",
"targetNodeFieldDoesNotExist": "Ongeldige rand: doel-/invoerveld {{node}}.{{field}} bestaat niet",
"reorderLinearView": "Herorden lineaire weergave",
"newWorkflowDesc": "Een nieuwe werkstroom aanmaken?",
"collectionOrScalarFieldType": "Verzameling|scalair {{name}}",
"newWorkflow": "Nieuwe werkstroom",
"unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom",
"unsupportedAnyOfLength": "te veel union-leden ({{count}})",
"unknownOutput": "Onbekende uitvoer: {{name}}",
"viewMode": "Gebruik in lineaire weergave",
"unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref",
"unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}",
"unknownNodeType": "Onbekend soort knooppunt",
"edit": "Bewerk",
"updateAllNodes": "Werk knooppunten bij",
"allNodesUpdated": "Alle knooppunten bijgewerkt",
"nodeVersion": "Knooppuntversie",
"newWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"clearWorkflow": "Maak werkstroom leeg",
"clearWorkflowDesc": "Deze werkstroom leegmaken en met een nieuwe beginnen?",
"inputFieldTypeParseError": "Fout bij bepalen van het soort invoerveld {{node}}.{{field}} ({{message}})",
"outputFieldTypeParseError": "Fout bij het bepalen van het soort uitvoerveld {{node}}.{{field}} ({{message}})",
"unableToExtractEnumOptions": "fout bij extraheren enumeratie-opties",
"unknownFieldType": "Soort $t(nodes.unknownField): {{type}}",
"unableToGetWorkflowVersion": "Fout bij ophalen schemaversie van werkstroom",
"betaDesc": "Deze uitvoering is in bèta. Totdat deze stabiel is kunnen er wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. We zijn van plan om deze uitvoering op lange termijn te gaan ondersteunen.",
"prototypeDesc": "Deze uitvoering is een prototype. Er kunnen wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. Deze kunnen op een willekeurig moment verwijderd worden.",
"noFieldsViewMode": "Deze werkstroom heeft geen geselecteerde velden om te tonen. Bekijk de volledige werkstroom om de waarden te configureren.",
"unableToUpdateNodes_one": "Fout bij bijwerken van {{count}} knooppunt",
"unableToUpdateNodes_other": "Fout bij bijwerken van {{count}} knooppunten"
}, },
"controlnet": { "controlnet": {
"amult": "a_mult", "amult": "a_mult",
@ -691,9 +873,28 @@
"canny": "Canny", "canny": "Canny",
"depthZoeDescription": "Genereer diepteblad via Zoe", "depthZoeDescription": "Genereer diepteblad via Zoe",
"hedDescription": "Herkenning van holistisch-geneste randen", "hedDescription": "Herkenning van holistisch-geneste randen",
"setControlImageDimensions": "Stel afmetingen controle-afbeelding in op B/H", "setControlImageDimensions": "Kopieer grootte naar B/H (optimaliseer voor model)",
"scribble": "Krabbel", "scribble": "Krabbel",
"maxFaces": "Max. gezichten" "maxFaces": "Max. gezichten",
"dwOpenpose": "DW Openpose",
"depthAnything": "Depth Anything",
"base": "Basis",
"hands": "Handen",
"selectCLIPVisionModel": "Selecteer een CLIP Vision-model",
"modelSize": "Modelgrootte",
"small": "Klein",
"large": "Groot",
"resizeSimple": "Wijzig grootte (eenvoudig)",
"beginEndStepPercentShort": "Begin-/eind-%",
"depthAnythingDescription": "Genereren dieptekaart d.m.v. de techniek Depth Anything",
"face": "Gezicht",
"body": "Lichaam",
"dwOpenposeDescription": "Schatting menselijke pose d.m.v. DW Openpose",
"ipAdapterMethod": "Methode",
"full": "Volledig",
"style": "Alleen stijl",
"composition": "Alleen samenstelling",
"setControlImageDimensionsForce": "Kopieer grootte naar B/H (negeer model)"
}, },
"dynamicPrompts": { "dynamicPrompts": {
"seedBehaviour": { "seedBehaviour": {
@ -706,7 +907,10 @@
"maxPrompts": "Max. prompts", "maxPrompts": "Max. prompts",
"promptsWithCount_one": "{{count}} prompt", "promptsWithCount_one": "{{count}} prompt",
"promptsWithCount_other": "{{count}} prompts", "promptsWithCount_other": "{{count}} prompts",
"dynamicPrompts": "Dynamische prompts" "dynamicPrompts": "Dynamische prompts",
"showDynamicPrompts": "Toon dynamische prompts",
"loading": "Genereren van dynamische prompts...",
"promptsPreview": "Voorvertoning prompts"
}, },
"popovers": { "popovers": {
"noiseUseCPU": { "noiseUseCPU": {
@ -719,7 +923,7 @@
}, },
"paramScheduler": { "paramScheduler": {
"paragraphs": [ "paragraphs": [
"De planner bepaalt hoe ruis per iteratie wordt toegevoegd aan een afbeelding of hoe een monster wordt bijgewerkt op basis van de uitvoer van een model." "De planner gebruikt gedurende het genereringsproces."
], ],
"heading": "Planner" "heading": "Planner"
}, },
@ -806,8 +1010,8 @@
}, },
"clipSkip": { "clipSkip": {
"paragraphs": [ "paragraphs": [
"Kies hoeveel CLIP-modellagen je wilt overslaan.", "Aantal over te slaan CLIP-modellagen.",
"Bepaalde modellen werken beter met bepaalde Overslaan CLIP-instellingen." "Bepaalde modellen zijn beter geschikt met bepaalde Overslaan CLIP-instellingen."
], ],
"heading": "Overslaan CLIP" "heading": "Overslaan CLIP"
}, },
@ -991,17 +1195,26 @@
"denoisingStrength": "Sterkte ontruising", "denoisingStrength": "Sterkte ontruising",
"refinermodel": "Verfijningsmodel", "refinermodel": "Verfijningsmodel",
"posAestheticScore": "Positieve esthetische score", "posAestheticScore": "Positieve esthetische score",
"concatPromptStyle": "Plak prompt- en stijltekst aan elkaar", "concatPromptStyle": "Koppelen van prompt en stijl",
"loading": "Bezig met laden...", "loading": "Bezig met laden...",
"steps": "Stappen", "steps": "Stappen",
"posStylePrompt": "Positieve-stijlprompt" "posStylePrompt": "Positieve-stijlprompt",
"freePromptStyle": "Handmatige stijlprompt",
"refinerSteps": "Aantal stappen verfijner"
}, },
"models": { "models": {
"noMatchingModels": "Geen overeenkomend modellen", "noMatchingModels": "Geen overeenkomend modellen",
"loading": "bezig met laden", "loading": "bezig met laden",
"noMatchingLoRAs": "Geen overeenkomende LoRA's", "noMatchingLoRAs": "Geen overeenkomende LoRA's",
"noModelsAvailable": "Geen modellen beschikbaar", "noModelsAvailable": "Geen modellen beschikbaar",
"selectModel": "Kies een model" "selectModel": "Kies een model",
"noLoRAsInstalled": "Geen LoRA's geïnstalleerd",
"noRefinerModelsInstalled": "Geen SDXL-verfijningsmodellen geïnstalleerd",
"defaultVAE": "Standaard-VAE",
"lora": "LoRA",
"esrganModel": "ESRGAN-model",
"addLora": "Voeg LoRA toe",
"concepts": "Concepten"
}, },
"boards": { "boards": {
"autoAddBoard": "Voeg automatisch bord toe", "autoAddBoard": "Voeg automatisch bord toe",
@ -1019,7 +1232,13 @@
"downloadBoard": "Download bord", "downloadBoard": "Download bord",
"changeBoard": "Wijzig bord", "changeBoard": "Wijzig bord",
"loading": "Bezig met laden...", "loading": "Bezig met laden...",
"clearSearch": "Maak zoekopdracht leeg" "clearSearch": "Maak zoekopdracht leeg",
"deleteBoard": "Verwijder bord",
"deleteBoardAndImages": "Verwijder bord en afbeeldingen",
"deleteBoardOnly": "Verwijder alleen bord",
"deletedBoardsCannotbeRestored": "Verwijderde borden kunnen niet worden hersteld",
"movingImagesToBoard_one": "Verplaatsen van {{count}} afbeelding naar bord:",
"movingImagesToBoard_other": "Verplaatsen van {{count}} afbeeldingen naar bord:"
}, },
"invocationCache": { "invocationCache": {
"disable": "Schakel uit", "disable": "Schakel uit",
@ -1036,5 +1255,39 @@
"clear": "Wis", "clear": "Wis",
"maxCacheSize": "Max. grootte cache", "maxCacheSize": "Max. grootte cache",
"cacheSize": "Grootte cache" "cacheSize": "Grootte cache"
},
"accordions": {
"generation": {
"title": "Genereren"
},
"image": {
"title": "Afbeelding"
},
"advanced": {
"title": "Geavanceerd",
"options": "$t(accordions.advanced.title) Opties"
},
"control": {
"title": "Besturing"
},
"compositing": {
"title": "Samenstellen",
"coherenceTab": "Coherentiefase",
"infillTab": "Invullen"
}
},
"hrf": {
"upscaleMethod": "Opschaalmethode",
"metadata": {
"strength": "Sterkte oplossing voor hoge resolutie",
"method": "Methode oplossing voor hoge resolutie",
"enabled": "Oplossing voor hoge resolutie ingeschakeld"
},
"hrf": "Oplossing voor hoge resolutie",
"enableHrf": "Schakel oplossing in voor hoge resolutie"
},
"prompt": {
"addPromptTrigger": "Voeg prompttrigger toe",
"compatibleEmbeddings": "Compatibele embeddings"
} }
} }

View File

@ -1594,7 +1594,6 @@
"deleteAll": "Удалить всё", "deleteAll": "Удалить всё",
"addLayer": "Добавить слой", "addLayer": "Добавить слой",
"moveToFront": "На передний план", "moveToFront": "На передний план",
"toggleVisibility": "Переключить видимость слоя",
"addPositivePrompt": "Добавить $t(common.positivePrompt)", "addPositivePrompt": "Добавить $t(common.positivePrompt)",
"addIPAdapter": "Добавить $t(common.ipAdapter)", "addIPAdapter": "Добавить $t(common.ipAdapter)",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)", "regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",

View File

@ -21,10 +21,10 @@ import i18n from 'i18n';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { memo, useCallback, useEffect } from 'react'; import { memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary'; import { ErrorBoundary } from 'react-error-boundary';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import PreselectedImage from './PreselectedImage'; import PreselectedImage from './PreselectedImage';
import Toaster from './Toaster';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -46,6 +46,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
useSocketIO(); useSocketIO();
useGlobalModifiersInit(); useGlobalModifiersInit();
useGlobalHotkeys(); useGlobalHotkeys();
useGetOpenAPISchemaQuery();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone(); const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
@ -94,7 +95,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
<DeleteImageModal /> <DeleteImageModal />
<ChangeBoardModal /> <ChangeBoardModal />
<DynamicPromptsModal /> <DynamicPromptsModal />
<Toaster />
<PreselectedImage selectedImage={selectedImage} /> <PreselectedImage selectedImage={selectedImage} />
</ErrorBoundary> </ErrorBoundary>
); );

View File

@ -1,5 +1,8 @@
import { Button, Flex, Heading, Link, Text, useToast } from '@invoke-ai/ui-library'; import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { toast } from 'features/toast/toast';
import newGithubIssueUrl from 'new-github-issue-url'; import newGithubIssueUrl from 'new-github-issue-url';
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiArrowSquareOutBold, PiCopyBold } from 'react-icons/pi'; import { PiArrowCounterClockwiseBold, PiArrowSquareOutBold, PiCopyBold } from 'react-icons/pi';
@ -11,31 +14,39 @@ type Props = {
}; };
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => { const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
const toast = useToast();
const { t } = useTranslation(); const { t } = useTranslation();
const isLocal = useAppSelector((s) => s.config.isLocal);
const handleCopy = useCallback(() => { const handleCopy = useCallback(() => {
const text = JSON.stringify(serializeError(error), null, 2); const text = JSON.stringify(serializeError(error), null, 2);
navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``); navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``);
toast({ toast({
title: 'Error Copied', id: 'ERROR_COPIED',
title: t('toast.errorCopied'),
}); });
}, [error, toast]); }, [error, t]);
const url = useMemo( const url = useMemo(() => {
() => if (isLocal) {
newGithubIssueUrl({ return newGithubIssueUrl({
user: 'invoke-ai', user: 'invoke-ai',
repo: 'InvokeAI', repo: 'InvokeAI',
template: 'BUG_REPORT.yml', template: 'BUG_REPORT.yml',
title: `[bug]: ${error.name}: ${error.message}`, title: `[bug]: ${error.name}: ${error.message}`,
}), });
[error.message, error.name] } else {
); return 'https://support.invoke.ai/support/tickets/new';
}
}, [error.message, error.name, isLocal]);
return ( return (
<Flex layerStyle="body" w="100vw" h="100vh" alignItems="center" justifyContent="center" p={4}> <Flex layerStyle="body" w="100vw" h="100vh" alignItems="center" justifyContent="center" p={4}>
<Flex layerStyle="first" flexDir="column" borderRadius="base" justifyContent="center" gap={8} p={16}> <Flex layerStyle="first" flexDir="column" borderRadius="base" justifyContent="center" gap={8} p={16}>
<Heading>{t('common.somethingWentWrong')}</Heading> <Flex alignItems="center" gap="2">
<Image src={InvokeLogoYellow} alt="invoke-logo" w="24px" h="24px" minW="24px" minH="24px" userSelect="none" />
<Heading fontSize="2xl">{t('common.somethingWentWrong')}</Heading>
</Flex>
<Flex <Flex
layerStyle="second" layerStyle="second"
px={8} px={8}
@ -57,7 +68,9 @@ const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
{t('common.copyError')} {t('common.copyError')}
</Button> </Button>
<Link href={url} isExternal> <Link href={url} isExternal>
<Button leftIcon={<PiArrowSquareOutBold />}>{t('accessibility.createIssue')}</Button> <Button leftIcon={<PiArrowSquareOutBold />}>
{isLocal ? t('accessibility.createIssue') : t('accessibility.submitSupportTicket')}
</Button>
</Link> </Link>
</Flex> </Flex>
</Flex> </Flex>

View File

@ -1,44 +0,0 @@
import { useToast } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
import type { MakeToastArg } from 'features/system/util/makeToast';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useEffect } from 'react';
/**
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
* @returns null
*/
const Toaster = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector((s) => s.system.toastQueue);
const toast = useToast();
useEffect(() => {
toastQueue.forEach((t) => {
toast(t);
});
toastQueue.length > 0 && dispatch(clearToastQueue());
}, [dispatch, toast, toastQueue]);
return null;
};
/**
* Returns a function that can be used to make a toast.
* @example
* const toaster = useAppToaster();
* toaster('Hello world!');
* toaster({ title: 'Hello world!', status: 'success' });
* @returns A function that can be used to make a toast.
* @see makeToast
* @see MakeToastArg
* @see UseToastOptions
*/
export const useAppToaster = () => {
const dispatch = useAppDispatch();
const toaster = useCallback((arg: MakeToastArg) => dispatch(addToast(makeToast(arg))), [dispatch]);
return toaster;
};
export default memo(Toaster);

View File

@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
import type { MapStore } from 'nanostores'; import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores'; import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react'; import { useEffect, useMemo } from 'react';
import { setEventListeners } from 'services/events/setEventListeners';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import { setEventListeners } from 'services/events/util/setEventListeners';
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client'; import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
import { io } from 'socket.io-client'; import { io } from 'socket.io-client';

View File

@ -35,28 +35,22 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected'; import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded'; import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged'; import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress'; import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete'; import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError'; import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged'; import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved'; import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested'; import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store'; import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>; export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@ -104,18 +98,13 @@ addCommitStagingAreaImageListener(startAppListening);
// Socket.IO // Socket.IO
addGeneratorProgressEventListener(startAppListening); addGeneratorProgressEventListener(startAppListening);
addGraphExecutionStateCompleteEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening); addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening); addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening); addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening); addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening); addSocketDisconnectedEventListener(startAppListening);
addSocketSubscribedEventListener(startAppListening);
addSocketUnsubscribedEventListener(startAppListening);
addModelLoadEventListener(startAppListening); addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening); addModelInstallEventListener(startAppListening);
addSessionRetrievalErrorEventListener(startAppListening);
addInvocationRetrievalErrorEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening); addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening); addBulkDownloadListeners(startAppListening);

View File

@ -8,7 +8,7 @@ import {
resetCanvas, resetCanvas,
setInitialCanvasImage, setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
@ -30,22 +30,20 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
req.reset(); req.reset();
if (canceled > 0) { if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`); log.debug(`Canceled ${canceled} canvas batches`);
dispatch( toast({
addToast({ id: 'CANCEL_BATCH_SUCCEEDED',
title: t('queue.cancelBatchSucceeded'), title: t('queue.cancelBatchSucceeded'),
status: 'success', status: 'success',
}) });
);
} }
dispatch(canvasBatchIdsReset()); dispatch(canvasBatchIdsReset());
} catch { } catch {
log.error('Failed to cancel canvas batches'); log.error('Failed to cancel canvas batches');
dispatch( toast({
addToast({ id: 'CANCEL_BATCH_FAILED',
title: t('queue.cancelBatchFailed'), title: t('queue.cancelBatchFailed'),
status: 'error', status: 'error',
}) });
);
} }
}, },
}); });

View File

@ -1,8 +1,8 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { toast } from 'common/util/toast';
import { zPydanticValidationError } from 'features/system/store/zodSchemas'; import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es'; import { truncate, upperFirst } from 'lodash-es';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
@ -16,18 +16,15 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
const arg = action.meta.arg.originalArgs; const arg = action.meta.arg.originalArgs;
logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued'); logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
if (!toast.isActive('batch-queued')) { toast({
toast({ id: 'QUEUE_BATCH_SUCCEEDED',
id: 'batch-queued', title: t('queue.batchQueued'),
title: t('queue.batchQueued'), status: 'success',
description: t('queue.batchQueuedDesc', { description: t('queue.batchQueuedDesc', {
count: response.enqueued, count: response.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'), direction: arg.prepend ? t('queue.front') : t('queue.back'),
}), }),
duration: 1000, });
status: 'success',
});
}
}, },
}); });
@ -40,9 +37,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
if (!response) { if (!response) {
toast({ toast({
id: 'QUEUE_BATCH_FAILED',
title: t('queue.batchFailedToQueue'), title: t('queue.batchFailedToQueue'),
status: 'error', status: 'error',
description: 'Unknown Error', description: t('common.unknownError'),
}); });
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
return; return;
@ -52,7 +50,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
if (result.success) { if (result.success) {
result.data.data.detail.map((e) => { result.data.data.detail.map((e) => {
toast({ toast({
id: 'batch-failed-to-queue', id: 'QUEUE_BATCH_FAILED',
title: truncate(upperFirst(e.msg), { length: 128 }), title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error', status: 'error',
description: truncate( description: truncate(
@ -64,9 +62,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
}); });
} else if (response.status !== 403) { } else if (response.status !== 403) {
toast({ toast({
id: 'QUEUE_BATCH_FAILED',
title: t('queue.batchFailedToQueue'), title: t('queue.batchFailedToQueue'),
description: t('common.unknownError'),
status: 'error', status: 'error',
description: t('common.unknownError'),
}); });
} }
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));

View File

@ -1,13 +1,12 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
import { ExternalLink } from '@invoke-ai/ui-library'; import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'common/util/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { import {
socketBulkDownloadCompleted, socketBulkDownloadComplete,
socketBulkDownloadFailed, socketBulkDownloadError,
socketBulkDownloadStarted, socketBulkDownloadStarted,
} from 'services/events/actions'; } from 'services/events/actions';
@ -28,7 +27,6 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// Show the response message if it exists, otherwise show the default message // Show the response message if it exists, otherwise show the default message
description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'), description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'),
duration: null, duration: null,
isClosable: true,
}); });
}, },
}); });
@ -40,9 +38,9 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// There isn't any toast to update if we get this event. // There isn't any toast to update if we get this event.
toast({ toast({
id: 'BULK_DOWNLOAD_REQUEST_FAILED',
title: t('gallery.bulkDownloadRequestFailed'), title: t('gallery.bulkDownloadRequestFailed'),
status: 'success', status: 'error',
isClosable: true,
}); });
}, },
}); });
@ -56,7 +54,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
}); });
startAppListening({ startAppListening({
actionCreator: socketBulkDownloadCompleted, actionCreator: socketBulkDownloadComplete,
effect: async (action) => { effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed'); log.debug(action.payload.data, 'Bulk download preparation completed');
@ -65,7 +63,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`; const url = `/api/v1/images/download/${bulk_download_item_name}`;
const toastOptions: UseToastOptions = { toast({
id: bulk_download_item_name, id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'), title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success', status: 'success',
@ -77,38 +75,24 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
/> />
), ),
duration: null, duration: null,
isClosable: true, });
};
if (toast.isActive(bulk_download_item_name)) {
toast.update(bulk_download_item_name, toastOptions);
} else {
toast(toastOptions);
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketBulkDownloadFailed, actionCreator: socketBulkDownloadError,
effect: async (action) => { effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed'); log.debug(action.payload.data, 'Bulk download preparation failed');
const { bulk_download_item_name } = action.payload.data; const { bulk_download_item_name } = action.payload.data;
const toastOptions: UseToastOptions = { toast({
id: bulk_download_item_name, id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'), title: t('gallery.bulkDownloadFailed'),
status: 'error', status: 'error',
description: action.payload.data.error, description: action.payload.data.error,
duration: null, duration: null,
isClosable: true, });
};
if (toast.isActive(bulk_download_item_name)) {
toast.update(bulk_download_item_name, toastOptions);
} else {
toast(toastOptions);
}
}, },
}); });
}; };

View File

@ -2,14 +2,14 @@ import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions'; import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => { export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: canvasCopiedToClipboard, actionCreator: canvasCopiedToClipboard,
effect: async (action, { dispatch, getState }) => { effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' }); const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const state = getState(); const state = getState();
@ -19,22 +19,20 @@ export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartLi
copyBlobToClipboard(blob); copyBlobToClipboard(blob);
} catch (err) { } catch (err) {
moduleLog.error(String(err)); moduleLog.error(String(err));
dispatch( toast({
addToast({ id: 'CANVAS_COPY_FAILED',
title: t('toast.problemCopyingCanvas'), title: t('toast.problemCopyingCanvas'),
description: t('toast.problemCopyingCanvasDesc'), description: t('toast.problemCopyingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
dispatch( toast({
addToast({ id: 'CANVAS_COPY_SUCCEEDED',
title: t('toast.canvasCopiedClipboard'), title: t('toast.canvasCopiedClipboard'),
status: 'success', status: 'success',
}) });
);
}, },
}); });
}; };

View File

@ -3,13 +3,13 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasDownloadedAsImage } from 'features/canvas/store/actions'; import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { downloadBlob } from 'features/canvas/util/downloadBlob'; import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => { export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: canvasDownloadedAsImage, actionCreator: canvasDownloadedAsImage,
effect: async (action, { dispatch, getState }) => { effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' }); const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
const state = getState(); const state = getState();
@ -18,18 +18,17 @@ export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartLi
blob = await getBaseLayerBlob(state); blob = await getBaseLayerBlob(state);
} catch (err) { } catch (err) {
moduleLog.error(String(err)); moduleLog.error(String(err));
dispatch( toast({
addToast({ id: 'CANVAS_DOWNLOAD_FAILED',
title: t('toast.problemDownloadingCanvas'), title: t('toast.problemDownloadingCanvas'),
description: t('toast.problemDownloadingCanvasDesc'), description: t('toast.problemDownloadingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
downloadBlob(blob, 'canvas.png'); downloadBlob(blob, 'canvas.png');
dispatch(addToast({ title: t('toast.canvasDownloaded'), status: 'success' })); toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
}, },
}); });
}; };

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasImageToControlAdapter } from 'features/canvas/store/actions'; import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -20,13 +20,12 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
blob = await getBaseLayerBlob(state, true); blob = await getBaseLayerBlob(state, true);
} catch (err) { } catch (err) {
log.error(String(err)); log.error(String(err));
dispatch( toast({
addToast({ id: 'PROBLEM_SAVING_CANVAS',
title: t('toast.problemSavingCanvas'), title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'), description: t('toast.problemSavingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -43,7 +42,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
crop_visible: false, crop_visible: false,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.canvasSentControlnetAssets') }, title: t('toast.canvasSentControlnetAssets'),
}, },
}) })
).unwrap(); ).unwrap();

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions'; import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -29,13 +29,12 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL
if (!maskBlob) { if (!maskBlob) {
log.error('Problem getting mask layer blob'); log.error('Problem getting mask layer blob');
dispatch( toast({
addToast({ id: 'PROBLEM_SAVING_MASK',
title: t('toast.problemSavingMask'), title: t('toast.problemSavingMask'),
description: t('toast.problemSavingMaskDesc'), description: t('toast.problemSavingMaskDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -52,7 +51,7 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL
crop_visible: true, crop_visible: true,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.maskSavedAssets') }, title: t('toast.maskSavedAssets'),
}, },
}) })
); );

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions'; import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -30,13 +30,12 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
if (!maskBlob) { if (!maskBlob) {
log.error('Problem getting mask layer blob'); log.error('Problem getting mask layer blob');
dispatch( toast({
addToast({ id: 'PROBLEM_IMPORTING_MASK',
title: t('toast.problemImportingMask'), title: t('toast.problemImportingMask'),
description: t('toast.problemImportingMaskDesc'), description: t('toast.problemImportingMaskDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -53,7 +52,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
crop_visible: false, crop_visible: false,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.maskSentControlnetAssets') }, title: t('toast.maskSentControlnetAssets'),
}, },
}) })
).unwrap(); ).unwrap();

View File

@ -4,7 +4,7 @@ import { canvasMerged } from 'features/canvas/store/actions';
import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore'; import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice'; import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob'; import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -17,13 +17,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
if (!blob) { if (!blob) {
moduleLog.error('Problem getting base layer blob'); moduleLog.error('Problem getting base layer blob');
dispatch( toast({
addToast({ id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'), title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'), description: t('toast.problemMergingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -31,13 +30,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
if (!canvasBaseLayer) { if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer'); moduleLog.error('Problem getting canvas base layer');
dispatch( toast({
addToast({ id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'), title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'), description: t('toast.problemMergingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -54,7 +52,7 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
is_intermediate: true, is_intermediate: true,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.canvasMerged') }, title: t('toast.canvasMerged'),
}, },
}) })
).unwrap(); ).unwrap();

View File

@ -1,8 +1,9 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions'; import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -18,13 +19,12 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
blob = await getBaseLayerBlob(state); blob = await getBaseLayerBlob(state);
} catch (err) { } catch (err) {
log.error(String(err)); log.error(String(err));
dispatch( toast({
addToast({ id: 'CANVAS_SAVE_FAILED',
title: t('toast.problemSavingCanvas'), title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'), description: t('toast.problemSavingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -41,7 +41,10 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
crop_visible: true, crop_visible: true,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.canvasSavedGallery') }, title: t('toast.canvasSavedGallery'),
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),
}, },
}) })
); );

View File

@ -14,8 +14,9 @@ import {
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters'; import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common'; import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images'; import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types'; import type { BatchConfig } from 'services/api/types';
@ -47,8 +48,10 @@ const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batc
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => { export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
matcher, matcher,
effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => { effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId; const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const state = getState();
const originalState = getOriginalState();
// Cancel any in-progress instances of this listener // Cancel any in-progress instances of this listener
cancelActiveListeners(); cancelActiveListeners();
@ -57,21 +60,33 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// Delay before starting actual work // Delay before starting actual work
await delay(DEBOUNCE_MS); await delay(DEBOUNCE_MS);
// Double-check that we are still eligible for processing
const state = getState();
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId); const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
// If we have no image or there is no processor config, bail
if (!layer) { if (!layer) {
return; return;
} }
// We should only process if the processor settings or image have changed
const originalLayer = originalState.controlLayers.present.layers
.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const originalImage = originalLayer?.controlAdapter.image;
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image; const image = layer.controlAdapter.image;
const config = layer.controlAdapter.processorConfig; const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
// Neither config nor image have changed, we can bail
return;
}
if (!image || !config) { if (!image || !config) {
// The user has reset the image or config, so we should clear the processed image // - If we have no image, we have nothing to process
// - If we have no processor config, we have nothing to process
// Clear the processed image and bail
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null })); dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
return;
} }
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected. // At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
@ -81,8 +96,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId); cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
} }
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error... // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config); const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
const enqueueBatchArg: BatchConfig = { const enqueueBatchArg: BatchConfig = {
prepend: true, prepend: true,
batch: { batch: {
@ -118,8 +133,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
const [invocationCompleteAction] = await take( const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> => (action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) && socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === processorNode.id action.payload.data.invocation_source_id === processorNode.id
); );
// We still have to check the output type // We still have to check the output type
@ -159,12 +174,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
} }
} }
dispatch( toast({
addToast({ id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'), title: t('queue.graphFailedToQueue'),
status: 'error', status: 'error',
}) });
);
} }
} finally { } finally {
req.reset(); req.reset();

View File

@ -10,7 +10,7 @@ import {
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageOutput } from 'features/nodes/types/common'; import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
const [invocationCompleteAction] = await take( const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> => (action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) && socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === nodeId action.payload.data.invocation_source_id === nodeId
); );
// We still have to check the output type // We still have to check the output type
@ -108,12 +108,11 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
} }
} }
dispatch( toast({
addToast({ id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'), title: t('queue.graphFailedToQueue'),
status: 'error', status: 'error',
}) });
);
} }
}, },
}); });

View File

@ -1,4 +1,3 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@ -14,7 +13,7 @@ import {
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { omit } from 'lodash-es'; import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
@ -42,16 +41,17 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return; return;
} }
const DEFAULT_UPLOADED_TOAST: UseToastOptions = { const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'), title: t('toast.imageUploaded'),
status: 'success', status: 'success',
}; } as const;
// default action - just upload and alert user // default action - just upload and alert user
if (postUploadAction?.type === 'TOAST') { if (postUploadAction?.type === 'TOAST') {
const { toastOptions } = postUploadAction;
if (!autoAddBoardId || autoAddBoardId === 'none') { if (!autoAddBoardId || autoAddBoardId === 'none') {
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions })); const title = postUploadAction.title || DEFAULT_UPLOADED_TOAST.title;
toast({ ...DEFAULT_UPLOADED_TOAST, title });
} else { } else {
// Add this image to the board // Add this image to the board
dispatch( dispatch(
@ -70,24 +70,20 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
? `${t('toast.addedToBoard')} ${board.board_name}` ? `${t('toast.addedToBoard')} ${board.board_name}`
: `${t('toast.addedToBoard')} ${autoAddBoardId}`; : `${t('toast.addedToBoard')} ${autoAddBoardId}`;
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description,
description, });
})
);
} }
return; return;
} }
if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') { if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') {
dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state))); dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state)));
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description: t('toast.setAsCanvasInitialImage'),
description: t('toast.setAsCanvasInitialImage'), });
})
);
return; return;
} }
@ -105,68 +101,56 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
controlImage: imageDTO.image_name, controlImage: imageDTO.image_name,
}) })
); );
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage'),
description: t('toast.setControlImage'), });
})
);
return; return;
} }
if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') { if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') {
const { layerId } = postUploadAction; const { layerId } = postUploadAction;
dispatch(caLayerImageChanged({ layerId, imageDTO })); dispatch(caLayerImageChanged({ layerId, imageDTO }));
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage'),
description: t('toast.setControlImage'), });
})
);
} }
if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') { if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') {
const { layerId } = postUploadAction; const { layerId } = postUploadAction;
dispatch(ipaLayerImageChanged({ layerId, imageDTO })); dispatch(ipaLayerImageChanged({ layerId, imageDTO }));
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage'),
description: t('toast.setControlImage'), });
})
);
} }
if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') { if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') {
const { layerId, ipAdapterId } = postUploadAction; const { layerId, ipAdapterId } = postUploadAction;
dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO })); dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO }));
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage'),
description: t('toast.setControlImage'), });
})
);
} }
if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') { if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') {
const { layerId } = postUploadAction; const { layerId } = postUploadAction;
dispatch(iiLayerImageChanged({ layerId, imageDTO })); dispatch(iiLayerImageChanged({ layerId, imageDTO }));
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage'),
description: t('toast.setControlImage'), });
})
);
} }
if (postUploadAction?.type === 'SET_NODES_IMAGE') { if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction; const { nodeId, fieldName } = postUploadAction;
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })); dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
dispatch( toast({
addToast({ ...DEFAULT_UPLOADED_TOAST,
...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}`,
description: `${t('toast.setNodeField')} ${fieldName}`, });
})
);
return; return;
} }
}, },
@ -174,7 +158,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
startAppListening({ startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchRejected, matcher: imagesApi.endpoints.uploadImage.matchRejected,
effect: (action, { dispatch }) => { effect: (action) => {
const log = logger('images'); const log = logger('images');
const sanitizedData = { const sanitizedData = {
arg: { arg: {
@ -183,13 +167,11 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
}, },
}; };
log.error({ ...sanitizedData }, 'Image upload failed'); log.error({ ...sanitizedData }, 'Image upload failed');
dispatch( toast({
addToast({ title: t('toast.imageUploadFailed'),
title: t('toast.imageUploadFailed'), description: action.error.message,
description: action.error.message, status: 'error',
status: 'error', });
})
);
}, },
}); });
}; };

View File

@ -8,8 +8,7 @@ import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice'; import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { zParameterModel } from 'features/parameters/types/parameterSchemas'; import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
@ -60,16 +59,14 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}); });
if (modelsCleared > 0) { if (modelsCleared > 0) {
dispatch( toast({
addToast( id: 'BASE_MODEL_CHANGED',
makeToast({ title: t('toast.baseModelChanged'),
title: t('toast.baseModelChangedCleared', { description: t('toast.baseModelChangedCleared', {
count: modelsCleared, count: modelsCleared,
}), }),
status: 'warning', status: 'warning',
}) });
)
);
} }
} }

View File

@ -19,8 +19,7 @@ import {
isParameterWidth, isParameterWidth,
zParameterVAEModel, zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types';
@ -109,7 +108,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
} }
} }
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) }))); toast({ id: 'PARAMETER_SET', title: t('toast.parameterSet', { parameter: 'Default settings' }) });
} }
}, },
}); });

View File

@ -1,7 +1,8 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketGeneratorProgress } from 'services/events/actions';
@ -11,13 +12,14 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
startAppListening({ startAppListening({
actionCreator: socketGeneratorProgress, actionCreator: socketGeneratorProgress,
effect: (action) => { effect: (action) => {
log.trace(action.payload, `Generator progress`); log.trace(parseify(action.payload), `Generator progress`);
const { source_node_id, step, total_steps, progress_image } = action.payload.data; const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS; nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps; nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null; nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
} }
}, },
}); });

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketGraphExecutionStateComplete } from 'services/events/actions';
const log = logger('socketio');
export const addGraphExecutionStateCompleteEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketGraphExecutionStateComplete,
effect: (action) => {
log.debug(action.payload, 'Session complete');
},
});
};

View File

@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
actionCreator: socketInvocationComplete, actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState }) => {
const { data } = action.payload; const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`);
const { result, node, queue_batch_id, source_node_id } = data; const { result, invocation_source_id } = data;
// This complete event has an associated image output // This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = result.image; const { image_name } = data.result.image;
const { canvas, gallery } = getState(); const { canvas, gallery } = getState();
// This populates the `getImageDTO` cache // This populates the `getImageDTO` cache
@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe(); imageDTORequest.unsubscribe();
// Add canvas images to the staging area // Add canvas images to the staging area
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) { if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO)); dispatch(addImageToStagingArea(imageDTO));
} }
@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
} }
} }
const nes = deepClone($nodeExecutionStates.get()[source_node_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.COMPLETED; nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) { if (nes.progress !== null) {

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationError } from 'services/events/actions'; import { socketInvocationError } from 'services/events/actions';
@ -11,14 +12,18 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
startAppListening({ startAppListening({
actionCreator: socketInvocationError, actionCreator: socketInvocationError,
effect: (action) => { effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`); const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data;
const { source_node_id } = action.payload.data; log.error(parseify(action.payload), `Invocation error (${invocation.type})`);
const nes = deepClone($nodeExecutionStates.get()[source_node_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.FAILED; nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error;
nes.progress = null; nes.progress = null;
nes.progressImage = null; nes.progressImage = null;
nes.error = {
error_type,
error_message,
error_traceback,
};
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);
} }
}, },

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketInvocationRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action) => {
log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions'; import { socketInvocationStarted } from 'services/events/actions';
@ -11,9 +12,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
startAppListening({ startAppListening({
actionCreator: socketInvocationStarted, actionCreator: socketInvocationStarted,
effect: (action) => { effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`); log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`);
const { source_node_id } = action.payload.data; const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]); const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) { if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS; nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes); upsertExecutionState(nes.nodeId, nes);

View File

@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { import {
socketModelInstallCancelled, socketModelInstallCancelled,
socketModelInstallCompleted, socketModelInstallComplete,
socketModelInstallDownloading, socketModelInstallDownloadProgress,
socketModelInstallError, socketModelInstallError,
} from 'services/events/actions'; } from 'services/events/actions';
export const addModelInstallEventListener = (startAppListening: AppStartListening) => { export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketModelInstallDownloading, actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch }) => {
const { bytes, total_bytes, id } = action.payload.data; const { bytes, total_bytes, id } = action.payload.data;
@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallCompleted, actionCreator: socketModelInstallComplete,
effect: (action, { dispatch }) => { effect: (action, { dispatch }) => {
const { id } = action.payload.data; const { id } = action.payload.data;

View File

@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions'; import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
const log = logger('socketio'); const log = logger('socketio');
@ -8,10 +8,11 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({ startAppListening({
actionCreator: socketModelLoadStarted, actionCreator: socketModelLoadStarted,
effect: (action) => { effect: (action) => {
const { model_config, submodel_type } = action.payload.data; const { config, submodel_type } = action.payload.data;
const { name, base, type } = model_config; const { name, base, type } = config;
const extras: string[] = [base, type]; const extras: string[] = [base, type];
if (submodel_type) { if (submodel_type) {
extras.push(submodel_type); extras.push(submodel_type);
} }
@ -23,10 +24,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
}); });
startAppListening({ startAppListening({
actionCreator: socketModelLoadCompleted, actionCreator: socketModelLoadComplete,
effect: (action) => { effect: (action) => {
const { model_config, submodel_type } = action.payload.data; const { config, submodel_type } = action.payload.data;
const { name, base, type } = model_config; const { name, base, type } = config;
const extras: string[] = [base, type]; const extras: string[] = [base, type];
if (submodel_type) { if (submodel_type) {

View File

@ -3,6 +3,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions'; import { socketQueueItemStatusChanged } from 'services/events/actions';
@ -12,18 +14,38 @@ const log = logger('socketio');
export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => { export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketQueueItemStatusChanged, actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch, getState }) => {
// we've got new status for the queue item, batch and queue // we've got new status for the queue item, batch and queue
const { queue_item, batch_status, queue_status } = action.payload.data; const {
item_id,
session_id,
status,
started_at,
updated_at,
completed_at,
batch_status,
queue_status,
error_type,
error_message,
error_traceback,
} = action.payload.data;
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`); log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch( dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, { queueItemsAdapter.updateOne(draft, {
id: String(queue_item.item_id), id: String(item_id),
changes: queue_item, changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
},
}); });
}) })
); );
@ -43,23 +65,18 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status) queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)
); );
// Update the queue item status (this is the full queue item, including the session)
dispatch(
queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => {
if (!draft) {
return;
}
Object.assign(draft, queue_item);
})
);
// Invalidate caches for things we cannot update // Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again // TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch( dispatch(
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus']) queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
])
); );
if (['in_progress'].includes(action.payload.data.queue_item.status)) { if (status === 'in_progress') {
forEach($nodeExecutionStates.get(), (nes) => { forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) { if (!nes) {
return; return;
@ -72,6 +89,25 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
clone.outputs = []; clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone); $nodeExecutionStates.setKey(clone.nodeId, clone);
}); });
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
status: 'error',
duration: null,
updateDescription: isLocal,
description: (
<ErrorToastDescription
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
});
} }
}, },
}); });

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSessionRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action) => {
log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Subscribed');
},
});
};

View File

@ -1,13 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketUnsubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketUnsubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Unsubscribed');
},
});
};

View File

@ -1,6 +1,6 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { stagingAreaImageSaved } from 'features/canvas/store/actions'; import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -29,15 +29,14 @@ export const addStagingAreaImageSavedListener = (startAppListening: AppStartList
}) })
); );
} }
dispatch(addToast({ title: t('toast.imageSaved'), status: 'success' })); toast({ id: 'IMAGE_SAVED', title: t('toast.imageSaved'), status: 'success' });
} catch (error) { } catch (error) {
dispatch( toast({
addToast({ id: 'IMAGE_SAVE_FAILED',
title: t('toast.imageSavingFailed'), title: t('toast.imageSavingFailed'),
description: (error as Error)?.message, description: (error as Error)?.message,
status: 'error', status: 'error',
}) });
);
} }
}, },
}); });

View File

@ -1,12 +1,11 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions'; import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice'; import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error'; import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate'; import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartListening) => { export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartListening) => {
@ -31,7 +30,12 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
} }
try { try {
const updatedNode = updateNode(node, template); const updatedNode = updateNode(node, template);
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); dispatch(
nodesChanged([
{ type: 'remove', id: updatedNode.id },
{ type: 'add', item: updatedNode },
])
);
} catch (e) { } catch (e) {
if (e instanceof NodeUpdateError) { if (e instanceof NodeUpdateError) {
unableToUpdateCount++; unableToUpdateCount++;
@ -45,24 +49,18 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
count: unableToUpdateCount, count: unableToUpdateCount,
}) })
); );
dispatch( toast({
addToast( id: 'UNABLE_TO_UPDATE_NODES',
makeToast({ title: t('nodes.unableToUpdateNodes', {
title: t('nodes.unableToUpdateNodes', { count: unableToUpdateCount,
count: unableToUpdateCount, }),
}), });
})
)
);
} else { } else {
dispatch( toast({
addToast( id: 'ALL_NODES_UPDATED',
makeToast({ title: t('nodes.allNodesUpdated'),
title: t('nodes.allNodesUpdated'), status: 'success',
status: 'success', });
})
)
);
} }
}, },
}); });

View File

@ -4,7 +4,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph'; import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph';
import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale'; import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types'; import type { BatchConfig, ImageDTO } from 'services/api/types';
@ -29,12 +29,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
{ imageDTO }, { imageDTO },
t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge') // should never coalesce t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge') // should never coalesce
); );
dispatch( toast({
addToast({ id: 'NOT_ALLOWED_TO_UPSCALE',
title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -65,12 +64,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
if (error instanceof Object && 'status' in error && error.status === 403) { if (error instanceof Object && 'status' in error && error.status === 403) {
return; return;
} else { } else {
dispatch( toast({
addToast({ id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'), title: t('queue.graphFailedToQueue'),
status: 'error', status: 'error',
}) });
);
} }
} }
}, },

View File

@ -8,23 +8,23 @@ import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow'; import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'; import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types'; import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod'; import { z } from 'zod';
import { fromZodError } from 'zod-validation-error'; import { fromZodError } from 'zod-validation-error';
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => { const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) { if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information // Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow); const parsed = JSON.parse(data.workflow);
return validateWorkflow(parsed, templates); return await validateWorkflow(parsed, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
} else if (data.graph) { } else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout // Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph); const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true); const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return validateWorkflow(workflow, templates); return await validateWorkflow(workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess);
} else { } else {
throw new Error('No workflow or graph provided'); throw new Error('No workflow or graph provided');
} }
@ -33,13 +33,13 @@ const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => { export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: workflowLoadRequested, actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => { effect: async (action, { dispatch }) => {
const log = logger('nodes'); const log = logger('nodes');
const { data, asCopy } = action.payload; const { data, asCopy } = action.payload;
const nodeTemplates = $templates.get(); const nodeTemplates = $templates.get();
try { try {
const { workflow, warnings } = getWorkflow(data, nodeTemplates); const { workflow, warnings } = await getWorkflow(data, nodeTemplates);
if (asCopy) { if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow // If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
@ -48,23 +48,18 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
dispatch(workflowLoaded(workflow)); dispatch(workflowLoaded(workflow));
if (!warnings.length) { if (!warnings.length) {
dispatch( toast({
addToast( id: 'WORKFLOW_LOADED',
makeToast({ title: t('toast.workflowLoaded'),
title: t('toast.workflowLoaded'), status: 'success',
status: 'success', });
})
)
);
} else { } else {
dispatch( toast({
addToast( id: 'WORKFLOW_LOADED',
makeToast({ title: t('toast.loadedWithWarnings'),
title: t('toast.loadedWithWarnings'), status: 'warning',
status: 'warning', });
})
)
);
warnings.forEach(({ message, ...rest }) => { warnings.forEach(({ message, ...rest }) => {
log.warn(rest, message); log.warn(rest, message);
}); });
@ -77,54 +72,42 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
if (e instanceof WorkflowVersionError) { if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions // The workflow version was not recognized in the valid list of versions
log.error({ error: parseify(e) }, e.message); log.error({ error: parseify(e) }, e.message);
dispatch( toast({
addToast( id: 'UNABLE_TO_VALIDATE_WORKFLOW',
makeToast({ title: t('nodes.unableToValidateWorkflow'),
title: t('nodes.unableToValidateWorkflow'), status: 'error',
status: 'error', description: e.message,
description: e.message, });
})
)
);
} else if (e instanceof WorkflowMigrationError) { } else if (e instanceof WorkflowMigrationError) {
// There was a problem migrating the workflow to the latest version // There was a problem migrating the workflow to the latest version
log.error({ error: parseify(e) }, e.message); log.error({ error: parseify(e) }, e.message);
dispatch( toast({
addToast( id: 'UNABLE_TO_VALIDATE_WORKFLOW',
makeToast({ title: t('nodes.unableToValidateWorkflow'),
title: t('nodes.unableToValidateWorkflow'), status: 'error',
status: 'error', description: e.message,
description: e.message, });
})
)
);
} else if (e instanceof z.ZodError) { } else if (e instanceof z.ZodError) {
// There was a problem validating the workflow itself // There was a problem validating the workflow itself
const { message } = fromZodError(e, { const { message } = fromZodError(e, {
prefix: t('nodes.workflowValidation'), prefix: t('nodes.workflowValidation'),
}); });
log.error({ error: parseify(e) }, message); log.error({ error: parseify(e) }, message);
dispatch( toast({
addToast( id: 'UNABLE_TO_VALIDATE_WORKFLOW',
makeToast({ title: t('nodes.unableToValidateWorkflow'),
title: t('nodes.unableToValidateWorkflow'), status: 'error',
status: 'error', description: message,
description: message, });
})
)
);
} else { } else {
// Some other error occurred // Some other error occurred
log.error({ error: parseify(e) }, t('nodes.unknownErrorValidatingWorkflow')); log.error({ error: parseify(e) }, t('nodes.unknownErrorValidatingWorkflow'));
dispatch( toast({
addToast( id: 'UNABLE_TO_VALIDATE_WORKFLOW',
makeToast({ title: t('nodes.unableToValidateWorkflow'),
title: t('nodes.unableToValidateWorkflow'), status: 'error',
status: 'error', description: t('nodes.unknownErrorValidatingWorkflow'),
description: t('nodes.unknownErrorValidatingWorkflow'), });
})
)
);
} }
} }
}, },

View File

@ -74,6 +74,7 @@ export type AppConfig = {
maxUpscalePixels?: number; maxUpscalePixels?: number;
metadataFetchDebounce?: number; metadataFetchDebounce?: number;
workflowFetchDebounce?: number; workflowFetchDebounce?: number;
isLocal?: boolean;
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: string[]; disabledControlNetModels: string[];

View File

@ -1,11 +1,10 @@
import { useAppToaster } from 'app/components/Toaster';
import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob'; import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
export const useCopyImageToClipboard = () => { export const useCopyImageToClipboard = () => {
const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const imageUrlToBlob = useImageUrlToBlob(); const imageUrlToBlob = useImageUrlToBlob();
@ -16,12 +15,11 @@ export const useCopyImageToClipboard = () => {
const copyImageToClipboard = useCallback( const copyImageToClipboard = useCallback(
async (image_url: string) => { async (image_url: string) => {
if (!isClipboardAPIAvailable) { if (!isClipboardAPIAvailable) {
toaster({ toast({
id: 'PROBLEM_COPYING_IMAGE',
title: t('toast.problemCopyingImage'), title: t('toast.problemCopyingImage'),
description: "Your browser doesn't support the Clipboard API.", description: "Your browser doesn't support the Clipboard API.",
status: 'error', status: 'error',
duration: 2500,
isClosable: true,
}); });
} }
try { try {
@ -33,23 +31,21 @@ export const useCopyImageToClipboard = () => {
copyBlobToClipboard(blob); copyBlobToClipboard(blob);
toaster({ toast({
id: 'IMAGE_COPIED',
title: t('toast.imageCopied'), title: t('toast.imageCopied'),
status: 'success', status: 'success',
duration: 2500,
isClosable: true,
}); });
} catch (err) { } catch (err) {
toaster({ toast({
id: 'PROBLEM_COPYING_IMAGE',
title: t('toast.problemCopyingImage'), title: t('toast.problemCopyingImage'),
description: String(err), description: String(err),
status: 'error', status: 'error',
duration: 2500,
isClosable: true,
}); });
} }
}, },
[imageUrlToBlob, isClipboardAPIAvailable, t, toaster] [imageUrlToBlob, isClipboardAPIAvailable, t]
); );
return { isClipboardAPIAvailable, copyImageToClipboard }; return { isClipboardAPIAvailable, copyImageToClipboard };

View File

@ -1,13 +1,12 @@
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { $authToken } from 'app/store/nanostores/authToken'; import { $authToken } from 'app/store/nanostores/authToken';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { imageDownloaded } from 'features/gallery/store/actions'; import { imageDownloaded } from 'features/gallery/store/actions';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
export const useDownloadImage = () => { export const useDownloadImage = () => {
const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const authToken = useStore($authToken); const authToken = useStore($authToken);
@ -37,16 +36,15 @@ export const useDownloadImage = () => {
window.URL.revokeObjectURL(url); window.URL.revokeObjectURL(url);
dispatch(imageDownloaded()); dispatch(imageDownloaded());
} catch (err) { } catch (err) {
toaster({ toast({
id: 'PROBLEM_DOWNLOADING_IMAGE',
title: t('toast.problemDownloadingImage'), title: t('toast.problemDownloadingImage'),
description: String(err), description: String(err),
status: 'error', status: 'error',
duration: 2500,
isClosable: true,
}); });
} }
}, },
[t, toaster, dispatch, authToken] [t, dispatch, authToken]
); );
return { downloadImage }; return { downloadImage };

View File

@ -1,6 +1,6 @@
import { useAppToaster } from 'app/components/Toaster';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { toast } from 'features/toast/toast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react'; import { useCallback, useEffect, useState } from 'react';
import type { Accept, FileRejection } from 'react-dropzone'; import type { Accept, FileRejection } from 'react-dropzone';
@ -26,7 +26,6 @@ const selectPostUploadAction = createMemoizedSelector(activeTabNameSelector, (ac
export const useFullscreenDropzone = () => { export const useFullscreenDropzone = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const toaster = useAppToaster();
const postUploadAction = useAppSelector(selectPostUploadAction); const postUploadAction = useAppSelector(selectPostUploadAction);
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId); const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false); const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
@ -37,13 +36,14 @@ export const useFullscreenDropzone = () => {
(rejection: FileRejection) => { (rejection: FileRejection) => {
setIsHandlingUpload(true); setIsHandlingUpload(true);
toaster({ toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'), title: t('toast.uploadFailed'),
description: rejection.errors.map((error) => error.message).join('\n'), description: rejection.errors.map((error) => error.message).join('\n'),
status: 'error', status: 'error',
}); });
}, },
[t, toaster] [t]
); );
const fileAcceptedCallback = useCallback( const fileAcceptedCallback = useCallback(
@ -62,7 +62,8 @@ export const useFullscreenDropzone = () => {
const onDrop = useCallback( const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => { (acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
if (fileRejections.length > 1) { if (fileRejections.length > 1) {
toaster({ toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'), title: t('toast.uploadFailed'),
description: t('toast.uploadFailedInvalidUploadDesc'), description: t('toast.uploadFailedInvalidUploadDesc'),
status: 'error', status: 'error',
@ -78,7 +79,7 @@ export const useFullscreenDropzone = () => {
fileAcceptedCallback(file); fileAcceptedCallback(file);
}); });
}, },
[t, toaster, fileAcceptedCallback, fileRejectionCallback] [t, fileAcceptedCallback, fileRejectionCallback]
); );
const onDragOver = useCallback(() => { const onDragOver = useCallback(() => {

View File

@ -137,7 +137,7 @@ const createSelector = (templates: Templates) =>
if (l.controlAdapter.type === 't2i_adapter') { if (l.controlAdapter.type === 't2i_adapter') {
const multiple = model?.base === 'sdxl' ? 32 : 64; const multiple = model?.base === 'sdxl' ? 32 : 64;
if (size.width % multiple !== 0 || size.height % multiple !== 0) { if (size.width % multiple !== 0 || size.height % multiple !== 0) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions')); problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
} }
} }
} }

View File

@ -1,6 +0,0 @@
import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
export const { toast } = createStandaloneToast({
theme: theme,
defaultOptions: TOAST_OPTIONS.defaultOptions,
});

View File

@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id); state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
} }
const queueItemStatus = action.payload.data.queue_item.status; const queueItemStatus = action.payload.data.status;
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') { if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
resetStagingAreaIfEmpty(state); resetStagingAreaIfEmpty(state);
} }

View File

@ -4,7 +4,7 @@ import { CALayerControlAdapterWrapper } from 'features/controlLayers/components/
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import { layerSelected, selectCALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice'; import { layerSelected, selectCALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -26,7 +26,7 @@ export const CALayer = memo(({ layerId }: Props) => {
return ( return (
<LayerWrapper onClick={onClick} borderColor={isSelected ? 'base.400' : 'base.800'}> <LayerWrapper onClick={onClick} borderColor={isSelected ? 'base.400' : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}> <Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} /> <LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="control_adapter_layer" /> <LayerTitle type="control_adapter_layer" />
<Spacer /> <Spacer />
<CALayerOpacity layerId={layerId} /> <CALayerOpacity layerId={layerId} />

View File

@ -5,7 +5,7 @@ import { InitialImagePreview } from 'features/controlLayers/components/IILayer/I
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import { import {
iiLayerDenoisingStrengthChanged, iiLayerDenoisingStrengthChanged,
@ -66,7 +66,7 @@ export const IILayer = memo(({ layerId }: Props) => {
return ( return (
<LayerWrapper onClick={onClick} borderColor={layer.isSelected ? 'base.400' : 'base.800'}> <LayerWrapper onClick={onClick} borderColor={layer.isSelected ? 'base.400' : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}> <Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} /> <LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="initial_image_layer" /> <LayerTitle type="initial_image_layer" />
<Spacer /> <Spacer />
<IILayerOpacity layerId={layerId} /> <IILayerOpacity layerId={layerId} />

View File

@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IPALayerIPAdapterWrapper } from 'features/controlLayers/components/IPALayer/IPALayerIPAdapterWrapper'; import { IPALayerIPAdapterWrapper } from 'features/controlLayers/components/IPALayer/IPALayerIPAdapterWrapper';
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import { layerSelected, selectIPALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice'; import { layerSelected, selectIPALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -22,7 +22,7 @@ export const IPALayer = memo(({ layerId }: Props) => {
return ( return (
<LayerWrapper onClick={onClick} borderColor={isSelected ? 'base.400' : 'base.800'}> <LayerWrapper onClick={onClick} borderColor={isSelected ? 'base.400' : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}> <Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} /> <LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="ip_adapter_layer" /> <LayerTitle type="ip_adapter_layer" />
<Spacer /> <Spacer />
<LayerDeleteButton layerId={layerId} /> <LayerDeleteButton layerId={layerId} />

View File

@ -1,8 +1,8 @@
import { IconButton } from '@invoke-ai/ui-library'; import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { stopPropagation } from 'common/util/stopPropagation'; import { stopPropagation } from 'common/util/stopPropagation';
import { useLayerIsVisible } from 'features/controlLayers/hooks/layerStateHooks'; import { useLayerIsEnabled } from 'features/controlLayers/hooks/layerStateHooks';
import { layerVisibilityToggled } from 'features/controlLayers/store/controlLayersSlice'; import { layerIsEnabledToggled } from 'features/controlLayers/store/controlLayersSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiCheckBold } from 'react-icons/pi'; import { PiCheckBold } from 'react-icons/pi';
@ -11,21 +11,21 @@ type Props = {
layerId: string; layerId: string;
}; };
export const LayerVisibilityToggle = memo(({ layerId }: Props) => { export const LayerIsEnabledToggle = memo(({ layerId }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isVisible = useLayerIsVisible(layerId); const isEnabled = useLayerIsEnabled(layerId);
const onClick = useCallback(() => { const onClick = useCallback(() => {
dispatch(layerVisibilityToggled(layerId)); dispatch(layerIsEnabledToggled(layerId));
}, [dispatch, layerId]); }, [dispatch, layerId]);
return ( return (
<IconButton <IconButton
size="sm" size="sm"
aria-label={t('controlLayers.toggleVisibility')} aria-label={t(isEnabled ? 'common.enabled' : 'common.disabled')}
tooltip={t('controlLayers.toggleVisibility')} tooltip={t(isEnabled ? 'common.enabled' : 'common.disabled')}
variant="outline" variant="outline"
icon={isVisible ? <PiCheckBold /> : undefined} icon={isEnabled ? <PiCheckBold /> : undefined}
onClick={onClick} onClick={onClick}
colorScheme="base" colorScheme="base"
onDoubleClick={stopPropagation} // double click expands the layer onDoubleClick={stopPropagation} // double click expands the layer
@ -33,4 +33,4 @@ export const LayerVisibilityToggle = memo(({ layerId }: Props) => {
); );
}); });
LayerVisibilityToggle.displayName = 'LayerVisibilityToggle'; LayerIsEnabledToggle.displayName = 'LayerVisibilityToggle';

View File

@ -6,7 +6,7 @@ import { AddPromptButtons } from 'features/controlLayers/components/AddPromptBut
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu';
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper';
import { import {
isRegionalGuidanceLayer, isRegionalGuidanceLayer,
@ -55,7 +55,7 @@ export const RGLayer = memo(({ layerId }: Props) => {
return ( return (
<LayerWrapper onClick={onClick} borderColor={isSelected ? color : 'base.800'}> <LayerWrapper onClick={onClick} borderColor={isSelected ? color : 'base.800'}>
<Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}> <Flex gap={3} alignItems="center" p={3} cursor="pointer" onDoubleClick={onToggle}>
<LayerVisibilityToggle layerId={layerId} /> <LayerIsEnabledToggle layerId={layerId} />
<LayerTitle type="regional_guidance_layer" /> <LayerTitle type="regional_guidance_layer" />
<Spacer /> <Spacer />
{autoNegative === 'invert' && ( {autoNegative === 'invert' && (

View File

@ -45,7 +45,6 @@ export const RGLayerNegativePrompt = memo(({ layerId }: Props) => {
variant="darkFilled" variant="darkFilled"
paddingRight={30} paddingRight={30}
fontSize="sm" fontSize="sm"
spellCheck={false}
/> />
<PromptOverlayButtonWrapper> <PromptOverlayButtonWrapper>
<RGLayerPromptDeleteButton layerId={layerId} polarity="negative" /> <RGLayerPromptDeleteButton layerId={layerId} polarity="negative" />

View File

@ -45,7 +45,6 @@ export const RGLayerPositivePrompt = memo(({ layerId }: Props) => {
variant="darkFilled" variant="darkFilled"
paddingRight={30} paddingRight={30}
minH={28} minH={28}
spellCheck={false}
/> />
<PromptOverlayButtonWrapper> <PromptOverlayButtonWrapper>
<RGLayerPromptDeleteButton layerId={layerId} polarity="positive" /> <RGLayerPromptDeleteButton layerId={layerId} polarity="positive" />

Some files were not shown because too many files have changed in this diff Show More