mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
63 Commits
v3.5.1
...
test/node-
Author | SHA1 | Date | |
---|---|---|---|
787df67ceb | |||
4b149ab521 | |||
25ecf08962 | |||
f051daea4f | |||
143d7e03ef | |||
9d959b8cf3 | |||
1b3a6f4540 | |||
3623112807 | |||
8c65ade392 | |||
cba25766aa | |||
c2177c1778 | |||
7f2f085658 | |||
d9c816bdbb | |||
b86b72437f | |||
508c7ca9eb | |||
e7824ed176 | |||
0deb588a02 | |||
567a19e47a | |||
c572e0a6a0 | |||
c8869f543c | |||
8ec5e07011 | |||
6ba83350ff | |||
5ec405ebe6 | |||
e8ac82a492 | |||
0c5bafdeb6 | |||
507a429d42 | |||
7a6ea8a67f | |||
0738bcfe9b | |||
bdc7227b61 | |||
d93d5561b1 | |||
aab7c2c152 | |||
593d91815d | |||
cd9f0e026f | |||
7a1fe7548b | |||
e1b8874bc5 | |||
d2f102b6ab | |||
b039fb1e78 | |||
91cdccd217 | |||
968bc41bcc | |||
e44fbd0d53 | |||
bf7780079e | |||
385a8afacf | |||
fc0a2ddef3 | |||
b6dea0d3b5 | |||
e720c2cf19 | |||
68cef6d90a | |||
77150ab7cd | |||
916404745c | |||
a786615783 | |||
430e9346e6 | |||
2cd3bd8234 | |||
89c01547cb | |||
eb65c12e61 | |||
ab5c3ed189 | |||
b0e3791a80 | |||
1dd6b3a508 | |||
94e58a2254 | |||
d6aa55b965 | |||
184a6cd85f | |||
e2924067ef | |||
21d1e76ea9 | |||
bdeab50a82 | |||
40ff9ce672 |
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||||
@ -9,7 +10,10 @@ from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
|
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
|
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||||
|
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
@ -25,6 +29,7 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
|
|||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
|
from ..services.thread import lock
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -63,22 +68,32 @@ class ApiDependencies:
|
|||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_path = config.db_path
|
if config.use_memory_db:
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
db_location = ":memory:"
|
||||||
db_location = str(db_path)
|
else:
|
||||||
|
db_path = config.db_path
|
||||||
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
db_location = str(db_path)
|
||||||
|
|
||||||
|
logger.info(f"Using database at {db_location}")
|
||||||
|
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||||
|
|
||||||
|
if config.log_sql:
|
||||||
|
db_conn.set_trace_callback(print)
|
||||||
|
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
conn=db_conn, table_name="graph_executions", lock=lock
|
||||||
)
|
)
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
||||||
|
|
||||||
boards = BoardService(
|
boards = BoardService(
|
||||||
services=BoardServiceDependencies(
|
services=BoardServiceDependencies(
|
||||||
@ -120,18 +135,29 @@ class ApiDependencies:
|
|||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||||
|
session_processor=DefaultSessionProcessor(),
|
||||||
|
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
create_system_graphs(services.graph_library)
|
create_system_graphs(services.graph_library)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
|
||||||
|
try:
|
||||||
|
lock.acquire()
|
||||||
|
db_conn.execute("VACUUM;")
|
||||||
|
db_conn.commit()
|
||||||
|
logger.info("Cleaned database")
|
||||||
|
finally:
|
||||||
|
lock.release()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shutdown():
|
def shutdown():
|
||||||
if ApiDependencies.invoker:
|
if ApiDependencies.invoker:
|
||||||
|
247
invokeai/app/api/routers/session_queue.py
Normal file
247
invokeai/app/api/routers/session_queue.py
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Body, Path, Query
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
|
QUEUE_ITEM_STATUS,
|
||||||
|
Batch,
|
||||||
|
BatchStatus,
|
||||||
|
CancelByBatchIDsResult,
|
||||||
|
ClearResult,
|
||||||
|
EnqueueBatchResult,
|
||||||
|
EnqueueGraphResult,
|
||||||
|
PruneResult,
|
||||||
|
SessionQueueItem,
|
||||||
|
SessionQueueItemDTO,
|
||||||
|
SessionQueueStatus,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||||
|
|
||||||
|
from ...services.graph import Graph
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueAndProcessorStatus(BaseModel):
|
||||||
|
"""The overall status of session queue and processor"""
|
||||||
|
|
||||||
|
queue: SessionQueueStatus
|
||||||
|
processor: SessionProcessorStatus
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.post(
|
||||||
|
"/{queue_id}/enqueue_graph",
|
||||||
|
operation_id="enqueue_graph",
|
||||||
|
responses={
|
||||||
|
201: {"model": EnqueueGraphResult},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def enqueue_graph(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
graph: Graph = Body(description="The graph to enqueue"),
|
||||||
|
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||||
|
) -> EnqueueGraphResult:
|
||||||
|
"""Enqueues a graph for single execution."""
|
||||||
|
|
||||||
|
return ApiDependencies.invoker.services.session_queue.enqueue_graph(queue_id=queue_id, graph=graph, prepend=prepend)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.post(
|
||||||
|
"/{queue_id}/enqueue_batch",
|
||||||
|
operation_id="enqueue_batch",
|
||||||
|
responses={
|
||||||
|
201: {"model": EnqueueBatchResult},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def enqueue_batch(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
batch: Batch = Body(description="Batch to process"),
|
||||||
|
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||||
|
) -> EnqueueBatchResult:
|
||||||
|
"""Processes a batch and enqueues the output graphs for execution."""
|
||||||
|
|
||||||
|
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.get(
|
||||||
|
"/{queue_id}/list",
|
||||||
|
operation_id="list_queue_items",
|
||||||
|
responses={
|
||||||
|
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def list_queue_items(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
limit: int = Query(default=50, description="The number of items to fetch"),
|
||||||
|
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
|
||||||
|
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
|
||||||
|
priority: int = Query(default=0, description="The pagination cursor priority"),
|
||||||
|
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||||
|
"""Gets all queue items (without graphs)"""
|
||||||
|
|
||||||
|
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||||
|
queue_id=queue_id, limit=limit, status=status, order_id=cursor, priority=priority
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.put(
|
||||||
|
"/{queue_id}/processor/resume",
|
||||||
|
operation_id="resume",
|
||||||
|
responses={200: {"model": SessionProcessorStatus}},
|
||||||
|
)
|
||||||
|
async def resume(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> SessionProcessorStatus:
|
||||||
|
"""Resumes session processor"""
|
||||||
|
return ApiDependencies.invoker.services.session_processor.resume()
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.put(
|
||||||
|
"/{queue_id}/processor/pause",
|
||||||
|
operation_id="pause",
|
||||||
|
responses={200: {"model": SessionProcessorStatus}},
|
||||||
|
)
|
||||||
|
async def Pause(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> SessionProcessorStatus:
|
||||||
|
"""Pauses session processor"""
|
||||||
|
return ApiDependencies.invoker.services.session_processor.pause()
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.put(
|
||||||
|
"/{queue_id}/cancel_by_batch_ids",
|
||||||
|
operation_id="cancel_by_batch_ids",
|
||||||
|
responses={200: {"model": CancelByBatchIDsResult}},
|
||||||
|
)
|
||||||
|
async def cancel_by_batch_ids(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||||
|
) -> CancelByBatchIDsResult:
|
||||||
|
"""Immediately cancels all queue items from the given batch ids"""
|
||||||
|
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.put(
|
||||||
|
"/{queue_id}/clear",
|
||||||
|
operation_id="clear",
|
||||||
|
responses={
|
||||||
|
200: {"model": ClearResult},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def clear(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> ClearResult:
|
||||||
|
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||||
|
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||||
|
if queue_item is not None:
|
||||||
|
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||||
|
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||||
|
return clear_result
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.put(
|
||||||
|
"/{queue_id}/prune",
|
||||||
|
operation_id="prune",
|
||||||
|
responses={
|
||||||
|
200: {"model": PruneResult},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def prune(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> PruneResult:
|
||||||
|
"""Prunes all completed or errored queue items"""
|
||||||
|
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.get(
|
||||||
|
"/{queue_id}/current",
|
||||||
|
operation_id="get_current_queue_item",
|
||||||
|
responses={
|
||||||
|
200: {"model": Optional[SessionQueueItem]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_current_queue_item(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> Optional[SessionQueueItem]:
|
||||||
|
"""Gets the currently execution queue item"""
|
||||||
|
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.get(
|
||||||
|
"/{queue_id}/next",
|
||||||
|
operation_id="get_next_queue_item",
|
||||||
|
responses={
|
||||||
|
200: {"model": Optional[SessionQueueItem]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_next_queue_item(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> Optional[SessionQueueItem]:
|
||||||
|
"""Gets the next queue item, without executing it"""
|
||||||
|
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.get(
|
||||||
|
"/{queue_id}/status",
|
||||||
|
operation_id="get_queue_status",
|
||||||
|
responses={
|
||||||
|
200: {"model": SessionQueueAndProcessorStatus},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_queue_status(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> SessionQueueAndProcessorStatus:
|
||||||
|
"""Gets the status of the session queue"""
|
||||||
|
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||||
|
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||||
|
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.get(
|
||||||
|
"/{queue_id}/b/{batch_id}/status",
|
||||||
|
operation_id="get_batch_status",
|
||||||
|
responses={
|
||||||
|
200: {"model": BatchStatus},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_batch_status(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
batch_id: str = Path(description="The batch to get the status of"),
|
||||||
|
) -> BatchStatus:
|
||||||
|
"""Gets the status of the session queue"""
|
||||||
|
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.get(
|
||||||
|
"/{queue_id}/i/{item_id}",
|
||||||
|
operation_id="get_queue_item",
|
||||||
|
responses={
|
||||||
|
200: {"model": SessionQueueItem},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_queue_item(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
item_id: str = Path(description="The queue item to get"),
|
||||||
|
) -> SessionQueueItem:
|
||||||
|
"""Gets a queue item"""
|
||||||
|
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.put(
|
||||||
|
"/{queue_id}/i/{item_id}/cancel",
|
||||||
|
operation_id="cancel_queue_item",
|
||||||
|
responses={
|
||||||
|
200: {"model": SessionQueueItem},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def cancel_queue_item(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
item_id: str = Path(description="The queue item to cancel"),
|
||||||
|
) -> SessionQueueItem:
|
||||||
|
"""Deletes a queue item"""
|
||||||
|
|
||||||
|
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
@ -23,12 +23,14 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
|||||||
200: {"model": GraphExecutionState},
|
200: {"model": GraphExecutionState},
|
||||||
400: {"description": "Invalid json"},
|
400: {"description": "Invalid json"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def create_session(
|
async def create_session(
|
||||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||||
|
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||||
) -> GraphExecutionState:
|
) -> GraphExecutionState:
|
||||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
@ -36,6 +38,7 @@ async def create_session(
|
|||||||
"/",
|
"/",
|
||||||
operation_id="list_sessions",
|
operation_id="list_sessions",
|
||||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def list_sessions(
|
async def list_sessions(
|
||||||
page: int = Query(default=0, description="The page of results to get"),
|
page: int = Query(default=0, description="The page of results to get"),
|
||||||
@ -57,6 +60,7 @@ async def list_sessions(
|
|||||||
200: {"model": GraphExecutionState},
|
200: {"model": GraphExecutionState},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def get_session(
|
async def get_session(
|
||||||
session_id: str = Path(description="The id of the session to get"),
|
session_id: str = Path(description="The id of the session to get"),
|
||||||
@ -77,6 +81,7 @@ async def get_session(
|
|||||||
400: {"description": "Invalid node or link"},
|
400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def add_node(
|
async def add_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
@ -109,6 +114,7 @@ async def add_node(
|
|||||||
400: {"description": "Invalid node or link"},
|
400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def update_node(
|
async def update_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
@ -142,6 +148,7 @@ async def update_node(
|
|||||||
400: {"description": "Invalid node or link"},
|
400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def delete_node(
|
async def delete_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
@ -172,6 +179,7 @@ async def delete_node(
|
|||||||
400: {"description": "Invalid node or link"},
|
400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
@ -203,6 +211,7 @@ async def add_edge(
|
|||||||
400: {"description": "Invalid node or link"},
|
400: {"description": "Invalid node or link"},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def delete_edge(
|
async def delete_edge(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
@ -241,8 +250,10 @@ async def delete_edge(
|
|||||||
400: {"description": "The session has no invocations ready to invoke"},
|
400: {"description": "The session has no invocations ready to invoke"},
|
||||||
404: {"description": "Session not found"},
|
404: {"description": "Session not found"},
|
||||||
},
|
},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def invoke_session(
|
async def invoke_session(
|
||||||
|
queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||||
session_id: str = Path(description="The id of the session to invoke"),
|
session_id: str = Path(description="The id of the session to invoke"),
|
||||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
@ -254,7 +265,7 @@ async def invoke_session(
|
|||||||
if session.is_complete():
|
if session.is_complete():
|
||||||
raise HTTPException(status_code=400)
|
raise HTTPException(status_code=400)
|
||||||
|
|
||||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||||
return Response(status_code=202)
|
return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
@ -262,6 +273,7 @@ async def invoke_session(
|
|||||||
"/{session_id}/invoke",
|
"/{session_id}/invoke",
|
||||||
operation_id="cancel_session_invoke",
|
operation_id="cancel_session_invoke",
|
||||||
responses={202: {"description": "The invocation is canceled"}},
|
responses={202: {"description": "The invocation is canceled"}},
|
||||||
|
deprecated=True,
|
||||||
)
|
)
|
||||||
async def cancel_session_invoke(
|
async def cancel_session_invoke(
|
||||||
session_id: str = Path(description="The id of the session to cancel"),
|
session_id: str = Path(description="The id of the session to cancel"),
|
||||||
|
41
invokeai/app/api/routers/utilities.py
Normal file
41
invokeai/app/api/routers/utilities.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||||
|
from fastapi import Body
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pyparsing import ParseException
|
||||||
|
|
||||||
|
utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"])
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicPromptsResponse(BaseModel):
|
||||||
|
prompts: list[str]
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@utilities_router.post(
|
||||||
|
"/dynamicprompts",
|
||||||
|
operation_id="parse_dynamicprompts",
|
||||||
|
responses={
|
||||||
|
200: {"model": DynamicPromptsResponse},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def parse_dynamicprompts(
|
||||||
|
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
|
||||||
|
max_prompts: int = Body(default=1000, description="The max number of prompts to generate"),
|
||||||
|
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||||
|
) -> DynamicPromptsResponse:
|
||||||
|
"""Creates a batch process"""
|
||||||
|
try:
|
||||||
|
error: Optional[str] = None
|
||||||
|
if combinatorial:
|
||||||
|
generator = CombinatorialPromptGenerator()
|
||||||
|
prompts = generator.generate(prompt, max_prompts=max_prompts)
|
||||||
|
else:
|
||||||
|
generator = RandomPromptGenerator()
|
||||||
|
prompts = generator.generate(prompt, num_images=max_prompts)
|
||||||
|
except ParseException as e:
|
||||||
|
prompts = [prompt]
|
||||||
|
error = str(e)
|
||||||
|
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)
|
@ -13,24 +13,22 @@ class SocketIO:
|
|||||||
|
|
||||||
def __init__(self, app: FastAPI):
|
def __init__(self, app: FastAPI):
|
||||||
self.__sio = SocketManager(app=app)
|
self.__sio = SocketManager(app=app)
|
||||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
|
||||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
|
||||||
|
|
||||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||||
|
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||||
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||||
|
|
||||||
async def _handle_session_event(self, event: Event):
|
async def _handle_queue_event(self, event: Event):
|
||||||
await self.__sio.emit(
|
await self.__sio.emit(
|
||||||
event=event[1]["event"],
|
event=event[1]["event"],
|
||||||
data=event[1]["data"],
|
data=event[1]["data"],
|
||||||
room=event[1]["data"]["graph_execution_state_id"],
|
room=event[1]["data"]["queue_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
|
||||||
if "session" in data:
|
if "queue_id" in data:
|
||||||
self.__sio.enter_room(sid, data["session"])
|
self.__sio.enter_room(sid, data["queue_id"])
|
||||||
|
|
||||||
# @app.sio.on('unsubscribe')
|
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
|
||||||
|
if "queue_id" in data:
|
||||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
self.__sio.enter_room(sid, data["queue_id"])
|
||||||
if "session" in data:
|
|
||||||
self.__sio.leave_room(sid, data["session"])
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
@ -33,7 +32,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import app_info, board_images, boards, images, models, sessions
|
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
@ -92,6 +91,8 @@ async def shutdown_event():
|
|||||||
|
|
||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(models.models_router, prefix="/api")
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
@ -102,6 +103,8 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
|||||||
|
|
||||||
app.include_router(app_info.app_router, prefix="/api")
|
app.include_router(app_info.app_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
|
||||||
|
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||||
|
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
@ -12,6 +14,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
import argparse
|
import argparse
|
||||||
import re
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Union, get_type_hints
|
from typing import Optional, Union, get_type_hints
|
||||||
@ -249,19 +252,18 @@ def invoke_cli():
|
|||||||
db_location = config.db_path
|
db_location = config.db_path
|
||||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
filename=db_location, table_name="graph_executions"
|
|
||||||
)
|
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||||
|
|
||||||
boards = BoardService(
|
boards = BoardService(
|
||||||
services=BoardServiceDependencies(
|
services=BoardServiceDependencies(
|
||||||
@ -303,12 +305,13 @@ def invoke_cli():
|
|||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
configuration=config,
|
configuration=config,
|
||||||
|
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
system_graphs = create_system_graphs(services.graph_library)
|
system_graphs = create_system_graphs(services.graph_library)
|
||||||
|
@ -417,12 +417,18 @@ class UIConfigBase(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
class InvocationContext:
|
||||||
|
"""Initialized and provided to on execution of invocations."""
|
||||||
|
|
||||||
services: InvocationServices
|
services: InvocationServices
|
||||||
graph_execution_state_id: str
|
graph_execution_state_id: str
|
||||||
|
queue_id: str
|
||||||
|
queue_item_id: str
|
||||||
|
|
||||||
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
|
def __init__(self, services: InvocationServices, queue_id: str, queue_item_id: str, graph_execution_state_id: str):
|
||||||
self.services = services
|
self.services = services
|
||||||
self.graph_execution_state_id = graph_execution_state_id
|
self.graph_execution_state_id = graph_execution_state_id
|
||||||
|
self.queue_id = queue_id
|
||||||
|
self.queue_item_id = queue_item_id
|
||||||
|
|
||||||
|
|
||||||
class BaseInvocationOutput(BaseModel):
|
class BaseInvocationOutput(BaseModel):
|
||||||
@ -520,6 +526,9 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
return signature(cls.invoke).return_annotation
|
return signature(cls.invoke).return_annotation
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
validate_assignment = True
|
||||||
|
validate_all = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
uiconfig = getattr(model_class, "UIConfig", None)
|
uiconfig = getattr(model_class, "UIConfig", None)
|
||||||
@ -568,7 +577,29 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||||
elif _input == Input.Any:
|
elif _input == Input.Any:
|
||||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||||
return self.invoke(context)
|
|
||||||
|
# skip node cache codepath if it's disabled
|
||||||
|
if context.services.configuration.node_cache_size == 0:
|
||||||
|
return self.invoke(context)
|
||||||
|
|
||||||
|
output: BaseInvocationOutput
|
||||||
|
if self.use_cache:
|
||||||
|
key = context.services.invocation_cache.create_key(self)
|
||||||
|
cached_value = context.services.invocation_cache.get(key)
|
||||||
|
if cached_value is None:
|
||||||
|
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||||
|
output = self.invoke(context)
|
||||||
|
context.services.invocation_cache.save(output)
|
||||||
|
return output
|
||||||
|
else:
|
||||||
|
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||||
|
return cached_value
|
||||||
|
else:
|
||||||
|
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||||
|
return self.invoke(context)
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return self.__fields__["type"].default
|
||||||
|
|
||||||
id: str = Field(
|
id: str = Field(
|
||||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
||||||
@ -581,6 +612,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
description="The workflow to save with the image",
|
description="The workflow to save with the image",
|
||||||
ui_type=UIType.WorkflowField,
|
ui_type=UIType.WorkflowField,
|
||||||
)
|
)
|
||||||
|
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
||||||
|
|
||||||
@validator("workflow", pre=True)
|
@validator("workflow", pre=True)
|
||||||
def validate_workflow_is_json(cls, v):
|
def validate_workflow_is_json(cls, v):
|
||||||
@ -604,6 +636,7 @@ def invocation(
|
|||||||
tags: Optional[list[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
version: Optional[str] = None,
|
version: Optional[str] = None,
|
||||||
|
use_cache: Optional[bool] = True,
|
||||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||||
"""
|
"""
|
||||||
Adds metadata to an invocation.
|
Adds metadata to an invocation.
|
||||||
@ -636,6 +669,8 @@ def invocation(
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||||
cls.UIConfig.version = version
|
cls.UIConfig.version = version
|
||||||
|
if use_cache is not None:
|
||||||
|
cls.__fields__["use_cache"].default = use_cache
|
||||||
|
|
||||||
# Add the invocation type to the pydantic model of the invocation
|
# Add the invocation type to the pydantic model of the invocation
|
||||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
|
@ -56,6 +56,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
|||||||
tags=["range", "integer", "random", "collection"],
|
tags=["range", "integer", "random", "collection"],
|
||||||
category="collections",
|
category="collections",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
"""Creates a collection of random numbers"""
|
"""Creates a collection of random numbers"""
|
||||||
|
@ -965,3 +965,42 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
|
|||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"save_image",
|
||||||
|
title="Save Image",
|
||||||
|
tags=["primitives", "image"],
|
||||||
|
category="primitives",
|
||||||
|
version="1.0.0",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
class SaveImageInvocation(BaseInvocation):
|
||||||
|
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
|
||||||
|
|
||||||
|
image: ImageField = InputField(description="The image to load")
|
||||||
|
metadata: CoreMetadata = InputField(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.core_metadata,
|
||||||
|
ui_hidden=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
workflow=self.workflow,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
@ -54,7 +54,14 @@ class DivideInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=int(self.a / self.b))
|
return IntegerOutput(value=int(self.a / self.b))
|
||||||
|
|
||||||
|
|
||||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
|
@invocation(
|
||||||
|
"rand_int",
|
||||||
|
title="Random Integer",
|
||||||
|
tags=["math", "random"],
|
||||||
|
category="math",
|
||||||
|
version="1.0.0",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
class RandomIntInvocation(BaseInvocation):
|
class RandomIntInvocation(BaseInvocation):
|
||||||
"""Outputs a single random integer."""
|
"""Outputs a single random integer."""
|
||||||
|
|
||||||
|
@ -10,7 +10,14 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
|
|||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
|
@invocation(
|
||||||
|
"dynamic_prompt",
|
||||||
|
title="Dynamic Prompt",
|
||||||
|
tags=["prompt", "collection"],
|
||||||
|
category="prompt",
|
||||||
|
version="1.0.0",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
class DynamicPromptInvocation(BaseInvocation):
|
class DynamicPromptInvocation(BaseInvocation):
|
||||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||||
|
|
||||||
|
@ -53,24 +53,20 @@ class BoardImageRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = threading.Lock()
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
# Enable foreign keys
|
|
||||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
@ -8,6 +7,7 @@ from pydantic import BaseModel, Extra, Field
|
|||||||
|
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
|
||||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||||
@ -87,24 +87,20 @@ class BoardRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = threading.Lock()
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
# Enable foreign keys
|
|
||||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
@ -174,7 +170,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
board_name: str,
|
board_name: str,
|
||||||
) -> BoardRecord:
|
) -> BoardRecord:
|
||||||
try:
|
try:
|
||||||
board_id = str(uuid.uuid4())
|
board_id = uuid_string()
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
|
@ -16,7 +16,7 @@ import pydoc
|
|||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import ClassVar, Dict, List, Literal, Union, get_args, get_origin, get_type_hints
|
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
@ -39,10 +39,10 @@ class InvokeAISettings(BaseSettings):
|
|||||||
read from an omegaconf .yaml file.
|
read from an omegaconf .yaml file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
initconf: ClassVar[DictConfig] = None
|
initconf: ClassVar[Optional[DictConfig]] = None
|
||||||
argparse_groups: ClassVar[Dict] = {}
|
argparse_groups: ClassVar[Dict] = {}
|
||||||
|
|
||||||
def parse_args(self, argv: list = sys.argv[1:]):
|
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt, unknown_opts = parser.parse_known_args(argv)
|
opt, unknown_opts = parser.parse_known_args(argv)
|
||||||
if len(unknown_opts) > 0:
|
if len(unknown_opts) > 0:
|
||||||
@ -83,7 +83,8 @@ class InvokeAISettings(BaseSettings):
|
|||||||
else:
|
else:
|
||||||
settings_stanza = "Uncategorized"
|
settings_stanza = "Uncategorized"
|
||||||
|
|
||||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
env_prefix = getattr(cls.Config, "env_prefix", None)
|
||||||
|
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
||||||
|
|
||||||
initconf = (
|
initconf = (
|
||||||
cls.initconf.get(settings_stanza)
|
cls.initconf.get(settings_stanza)
|
||||||
@ -116,8 +117,8 @@ class InvokeAISettings(BaseSettings):
|
|||||||
field.default = current_default
|
field.default = current_default
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def cmd_name(self, command_field: str = "type") -> str:
|
def cmd_name(cls, command_field: str = "type") -> str:
|
||||||
hints = get_type_hints(self)
|
hints = get_type_hints(cls)
|
||||||
if command_field in hints:
|
if command_field in hints:
|
||||||
return get_args(hints[command_field])[0]
|
return get_args(hints[command_field])[0]
|
||||||
else:
|
else:
|
||||||
@ -133,16 +134,12 @@ class InvokeAISettings(BaseSettings):
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
def _excluded(cls) -> List[str]:
|
||||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _excluded(self) -> List[str]:
|
|
||||||
# internal fields that shouldn't be exposed as command line options
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ["type", "initconf"]
|
return ["type", "initconf"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded_from_yaml(self) -> List[str]:
|
def _excluded_from_yaml(cls) -> List[str]:
|
||||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||||
return [
|
return [
|
||||||
"type",
|
"type",
|
||||||
|
@ -194,8 +194,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
setting environment variables INVOKEAI_<setting>.
|
setting environment variables INVOKEAI_<setting>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||||
singleton_init: ClassVar[Dict] = None
|
singleton_init: ClassVar[Optional[Dict]] = None
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
@ -234,6 +234,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||||
|
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
|
||||||
|
|
||||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
||||||
|
|
||||||
@ -245,18 +246,23 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||||
|
|
||||||
# DEVICE
|
# DEVICE
|
||||||
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
|
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
|
||||||
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
|
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||||
|
|
||||||
# GENERATION
|
# GENERATION
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
||||||
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
|
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
|
||||||
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||||
|
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||||
|
|
||||||
|
# QUEUE
|
||||||
|
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
||||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
||||||
|
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
|
||||||
|
|
||||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
@ -272,7 +278,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
class Config:
|
class Config:
|
||||||
validate_assignment = True
|
validate_assignment = True
|
||||||
|
|
||||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
|
||||||
"""
|
"""
|
||||||
Update settings with contents of init file, environment, and
|
Update settings with contents of init file, environment, and
|
||||||
command-line settings.
|
command-line settings.
|
||||||
@ -283,12 +289,16 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
# Set the runtime root directory. We parse command-line switches here
|
# Set the runtime root directory. We parse command-line switches here
|
||||||
# in order to pick up the --root_dir option.
|
# in order to pick up the --root_dir option.
|
||||||
super().parse_args(argv)
|
super().parse_args(argv)
|
||||||
|
loaded_conf = None
|
||||||
if conf is None:
|
if conf is None:
|
||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
loaded_conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
InvokeAISettings.initconf = conf
|
if isinstance(loaded_conf, DictConfig):
|
||||||
|
InvokeAISettings.initconf = loaded_conf
|
||||||
|
else:
|
||||||
|
InvokeAISettings.initconf = conf
|
||||||
|
|
||||||
# parse args again in order to pick up settings in configuration file
|
# parse args again in order to pick up settings in configuration file
|
||||||
super().parse_args(argv)
|
super().parse_args(argv)
|
||||||
@ -376,13 +386,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
"""
|
"""
|
||||||
return self._resolve(self.models_dir)
|
return self._resolve(self.models_dir)
|
||||||
|
|
||||||
@property
|
|
||||||
def autoconvert_path(self) -> Path:
|
|
||||||
"""
|
|
||||||
Path to the directory containing models to be imported automatically at startup.
|
|
||||||
"""
|
|
||||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
|
||||||
|
|
||||||
# the following methods support legacy calls leftover from the Globals era
|
# the following methods support legacy calls leftover from the Globals era
|
||||||
@property
|
@property
|
||||||
def full_precision(self) -> bool:
|
def full_precision(self) -> bool:
|
||||||
@ -405,11 +408,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ram_cache_size(self) -> float:
|
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||||
return self.max_cache_size or self.ram
|
return self.max_cache_size or self.ram
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vram_cache_size(self) -> float:
|
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||||
return self.max_vram_cache_size or self.vram
|
return self.max_vram_cache_size or self.vram
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -10,57 +10,58 @@ default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
|||||||
|
|
||||||
|
|
||||||
def create_text_to_image() -> LibraryGraph:
|
def create_text_to_image() -> LibraryGraph:
|
||||||
|
graph = Graph(
|
||||||
|
nodes={
|
||||||
|
"width": IntegerInvocation(id="width", value=512),
|
||||||
|
"height": IntegerInvocation(id="height", value=512),
|
||||||
|
"seed": IntegerInvocation(id="seed", value=-1),
|
||||||
|
"3": NoiseInvocation(id="3"),
|
||||||
|
"4": CompelInvocation(id="4"),
|
||||||
|
"5": CompelInvocation(id="5"),
|
||||||
|
"6": DenoiseLatentsInvocation(id="6"),
|
||||||
|
"7": LatentsToImageInvocation(id="7"),
|
||||||
|
"8": ImageNSFWBlurInvocation(id="8"),
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="width", field="value"),
|
||||||
|
destination=EdgeConnection(node_id="3", field="width"),
|
||||||
|
),
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="height", field="value"),
|
||||||
|
destination=EdgeConnection(node_id="3", field="height"),
|
||||||
|
),
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="seed", field="value"),
|
||||||
|
destination=EdgeConnection(node_id="3", field="seed"),
|
||||||
|
),
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="3", field="noise"),
|
||||||
|
destination=EdgeConnection(node_id="6", field="noise"),
|
||||||
|
),
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="6", field="latents"),
|
||||||
|
destination=EdgeConnection(node_id="7", field="latents"),
|
||||||
|
),
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||||
|
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||||
|
),
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||||
|
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||||
|
),
|
||||||
|
Edge(
|
||||||
|
source=EdgeConnection(node_id="7", field="image"),
|
||||||
|
destination=EdgeConnection(node_id="8", field="image"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
return LibraryGraph(
|
return LibraryGraph(
|
||||||
id=default_text_to_image_graph_id,
|
id=default_text_to_image_graph_id,
|
||||||
name="t2i",
|
name="t2i",
|
||||||
description="Converts text to an image",
|
description="Converts text to an image",
|
||||||
graph=Graph(
|
graph=graph,
|
||||||
nodes={
|
|
||||||
"width": IntegerInvocation(id="width", value=512),
|
|
||||||
"height": IntegerInvocation(id="height", value=512),
|
|
||||||
"seed": IntegerInvocation(id="seed", value=-1),
|
|
||||||
"3": NoiseInvocation(id="3"),
|
|
||||||
"4": CompelInvocation(id="4"),
|
|
||||||
"5": CompelInvocation(id="5"),
|
|
||||||
"6": DenoiseLatentsInvocation(id="6"),
|
|
||||||
"7": LatentsToImageInvocation(id="7"),
|
|
||||||
"8": ImageNSFWBlurInvocation(id="8"),
|
|
||||||
},
|
|
||||||
edges=[
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="width", field="value"),
|
|
||||||
destination=EdgeConnection(node_id="3", field="width"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="height", field="value"),
|
|
||||||
destination=EdgeConnection(node_id="3", field="height"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="seed", field="value"),
|
|
||||||
destination=EdgeConnection(node_id="3", field="seed"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="3", field="noise"),
|
|
||||||
destination=EdgeConnection(node_id="6", field="noise"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="6", field="latents"),
|
|
||||||
destination=EdgeConnection(node_id="7", field="latents"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
|
||||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
|
||||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="7", field="image"),
|
|
||||||
destination=EdgeConnection(node_id="8", field="image"),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
exposed_inputs=[
|
exposed_inputs=[
|
||||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||||
|
@ -4,21 +4,23 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from invokeai.app.models.image import ProgressImage
|
from invokeai.app.models.image import ProgressImage
|
||||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
|
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
session_event: str = "session_event"
|
queue_event: str = "queue_event"
|
||||||
|
|
||||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||||
|
|
||||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||||
|
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||||
payload["timestamp"] = get_timestamp()
|
payload["timestamp"] = get_timestamp()
|
||||||
self.dispatch(
|
self.dispatch(
|
||||||
event_name=EventServiceBase.session_event,
|
event_name=EventServiceBase.queue_event,
|
||||||
payload=dict(event=event_name, data=payload),
|
payload=dict(event=event_name, data=payload),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,6 +28,8 @@ class EventServiceBase:
|
|||||||
# This will make them easier to integrate until we find a schema generator.
|
# This will make them easier to integrate until we find a schema generator.
|
||||||
def emit_generator_progress(
|
def emit_generator_progress(
|
||||||
self,
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
node: dict,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
@ -35,11 +39,13 @@ class EventServiceBase:
|
|||||||
total_steps: int,
|
total_steps: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when there is generation progress"""
|
"""Emitted when there is generation progress"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="generator_progress",
|
event_name="generator_progress",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
node_id=node.get("id"),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||||
step=step,
|
step=step,
|
||||||
@ -50,15 +56,19 @@ class EventServiceBase:
|
|||||||
|
|
||||||
def emit_invocation_complete(
|
def emit_invocation_complete(
|
||||||
self,
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
result: dict,
|
result: dict,
|
||||||
node: dict,
|
node: dict,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_complete",
|
event_name="invocation_complete",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
@ -68,6 +78,8 @@ class EventServiceBase:
|
|||||||
|
|
||||||
def emit_invocation_error(
|
def emit_invocation_error(
|
||||||
self,
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
node: dict,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
@ -75,9 +87,11 @@ class EventServiceBase:
|
|||||||
error: str,
|
error: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_error",
|
event_name="invocation_error",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
@ -86,28 +100,36 @@ class EventServiceBase:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
def emit_invocation_started(
|
||||||
|
self, queue_id: str, queue_item_id: str, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||||
|
) -> None:
|
||||||
"""Emitted when an invocation has started"""
|
"""Emitted when an invocation has started"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_started",
|
event_name="invocation_started",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
|
def emit_graph_execution_complete(self, queue_id: str, queue_item_id: str, graph_execution_state_id: str) -> None:
|
||||||
"""Emitted when a session has completed all invocations"""
|
"""Emitted when a session has completed all invocations"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="graph_execution_state_complete",
|
event_name="graph_execution_state_complete",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_load_started(
|
def emit_model_load_started(
|
||||||
self,
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
@ -115,9 +137,11 @@ class EventServiceBase:
|
|||||||
submodel: SubModelType,
|
submodel: SubModelType,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is requested"""
|
"""Emitted when a model is requested"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_load_started",
|
event_name="model_load_started",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
@ -128,6 +152,8 @@ class EventServiceBase:
|
|||||||
|
|
||||||
def emit_model_load_completed(
|
def emit_model_load_completed(
|
||||||
self,
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
@ -136,9 +162,11 @@ class EventServiceBase:
|
|||||||
model_info: ModelInfo,
|
model_info: ModelInfo,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_load_completed",
|
event_name="model_load_completed",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
@ -152,14 +180,18 @@ class EventServiceBase:
|
|||||||
|
|
||||||
def emit_session_retrieval_error(
|
def emit_session_retrieval_error(
|
||||||
self,
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
error_type: str,
|
error_type: str,
|
||||||
error: str,
|
error: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when session retrieval fails"""
|
"""Emitted when session retrieval fails"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="session_retrieval_error",
|
event_name="session_retrieval_error",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
error_type=error_type,
|
error_type=error_type,
|
||||||
error=error,
|
error=error,
|
||||||
@ -168,18 +200,74 @@ class EventServiceBase:
|
|||||||
|
|
||||||
def emit_invocation_retrieval_error(
|
def emit_invocation_retrieval_error(
|
||||||
self,
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
error_type: str,
|
error_type: str,
|
||||||
error: str,
|
error: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when invocation retrieval fails"""
|
"""Emitted when invocation retrieval fails"""
|
||||||
self.__emit_session_event(
|
self.__emit_queue_event(
|
||||||
event_name="invocation_retrieval_error",
|
event_name="invocation_retrieval_error",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
error_type=error_type,
|
error_type=error_type,
|
||||||
error=error,
|
error=error,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def emit_session_canceled(
|
||||||
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
queue_item_id: str,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Emitted when a session is canceled"""
|
||||||
|
self.__emit_queue_event(
|
||||||
|
event_name="session_canceled",
|
||||||
|
payload=dict(
|
||||||
|
queue_id=queue_id,
|
||||||
|
queue_item_id=queue_item_id,
|
||||||
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Emitted when a queue item's status changes"""
|
||||||
|
self.__emit_queue_event(
|
||||||
|
event_name="queue_item_status_changed",
|
||||||
|
payload=dict(
|
||||||
|
queue_id=session_queue_item.queue_id,
|
||||||
|
queue_item_id=session_queue_item.item_id,
|
||||||
|
status=session_queue_item.status,
|
||||||
|
batch_id=session_queue_item.batch_id,
|
||||||
|
session_id=session_queue_item.session_id,
|
||||||
|
error=session_queue_item.error,
|
||||||
|
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||||
|
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||||
|
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||||
|
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||||
|
"""Emitted when a batch is enqueued"""
|
||||||
|
self.__emit_queue_event(
|
||||||
|
event_name="batch_enqueued",
|
||||||
|
payload=dict(
|
||||||
|
queue_id=enqueue_result.queue_id,
|
||||||
|
batch_id=enqueue_result.batch.batch_id,
|
||||||
|
enqueued=enqueue_result.enqueued,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||||
|
"""Emitted when the queue is cleared"""
|
||||||
|
self.__emit_queue_event(
|
||||||
|
event_name="queue_cleared",
|
||||||
|
payload=dict(queue_id=queue_id),
|
||||||
|
)
|
||||||
|
@ -2,13 +2,14 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import uuid
|
from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin, get_type_hints
|
||||||
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from ..invocations import * # noqa: F401 F403
|
from ..invocations import * # noqa: F401 F403
|
||||||
from ..invocations.baseinvocation import (
|
from ..invocations.baseinvocation import (
|
||||||
@ -137,19 +138,31 @@ def are_connections_compatible(
|
|||||||
return are_connection_types_compatible(from_node_field, to_node_field)
|
return are_connection_types_compatible(from_node_field, to_node_field)
|
||||||
|
|
||||||
|
|
||||||
class NodeAlreadyInGraphError(Exception):
|
class NodeAlreadyInGraphError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidEdgeError(Exception):
|
class InvalidEdgeError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NodeNotFoundError(Exception):
|
class NodeNotFoundError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NodeAlreadyExecutedError(Exception):
|
class NodeAlreadyExecutedError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateNodeIdError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NodeFieldNotFoundError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NodeIdMismatchError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -227,7 +240,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
|||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||||
description="The nodes in this graph", default_factory=dict
|
description="The nodes in this graph", default_factory=dict
|
||||||
@ -237,6 +250,59 @@ class Graph(BaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_nodes_and_edges(cls, values):
|
||||||
|
"""Validates that all edges match nodes in the graph"""
|
||||||
|
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
|
||||||
|
edges = cast(Optional[list[Edge]], values.get("edges"))
|
||||||
|
|
||||||
|
if nodes is not None:
|
||||||
|
# Validate that all node ids are unique
|
||||||
|
node_ids = [n.id for n in nodes.values()]
|
||||||
|
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||||
|
if duplicate_node_ids:
|
||||||
|
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||||
|
|
||||||
|
# Validate that all node ids match the keys in the nodes dict
|
||||||
|
for k, v in nodes.items():
|
||||||
|
if k != v.id:
|
||||||
|
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||||
|
|
||||||
|
if edges is not None and nodes is not None:
|
||||||
|
# Validate that all edges match nodes in the graph
|
||||||
|
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
|
||||||
|
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
|
||||||
|
if missing_node_ids:
|
||||||
|
raise NodeNotFoundError(
|
||||||
|
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate that all edge fields match node fields in the graph
|
||||||
|
for edge in edges:
|
||||||
|
source_node = nodes.get(edge.source.node_id, None)
|
||||||
|
if source_node is None:
|
||||||
|
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
||||||
|
|
||||||
|
destination_node = nodes.get(edge.destination.node_id, None)
|
||||||
|
if destination_node is None:
|
||||||
|
raise NodeFieldNotFoundError(
|
||||||
|
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
|
||||||
|
)
|
||||||
|
|
||||||
|
# output fields are not on the node object directly, they are on the output type
|
||||||
|
if edge.source.field not in source_node.get_output_type().__fields__:
|
||||||
|
raise NodeFieldNotFoundError(
|
||||||
|
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# input fields are on the node
|
||||||
|
if edge.destination.field not in destination_node.__fields__:
|
||||||
|
raise NodeFieldNotFoundError(
|
||||||
|
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
def add_node(self, node: BaseInvocation) -> None:
|
def add_node(self, node: BaseInvocation) -> None:
|
||||||
"""Adds a node to a graph
|
"""Adds a node to a graph
|
||||||
|
|
||||||
@ -697,8 +763,7 @@ class Graph(BaseModel):
|
|||||||
class GraphExecutionState(BaseModel):
|
class GraphExecutionState(BaseModel):
|
||||||
"""Tracks the state of a graph execution"""
|
"""Tracks the state of a graph execution"""
|
||||||
|
|
||||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
|
||||||
|
|
||||||
# TODO: Store a reference to the graph instead of the actual graph?
|
# TODO: Store a reference to the graph instead of the actual graph?
|
||||||
graph: Graph = Field(description="The graph being executed")
|
graph: Graph = Field(description="The graph being executed")
|
||||||
|
|
||||||
@ -847,7 +912,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
new_node = copy.deepcopy(node)
|
new_node = copy.deepcopy(node)
|
||||||
|
|
||||||
# Create the node id (use a random uuid)
|
# Create the node id (use a random uuid)
|
||||||
new_node.id = str(uuid.uuid4())
|
new_node.id = uuid_string()
|
||||||
|
|
||||||
# Set the iteration index for iteration invocations
|
# Set the iteration index for iteration invocations
|
||||||
if isinstance(new_node, IterateInvocation):
|
if isinstance(new_node, IterateInvocation):
|
||||||
@ -1082,7 +1147,7 @@ class ExposedNodeOutput(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LibraryGraph(BaseModel):
|
class LibraryGraph(BaseModel):
|
||||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
|
||||||
graph: Graph = Field(description="The graph")
|
graph: Graph = Field(description="The graph")
|
||||||
name: str = Field(description="The name of the graph")
|
name: str = Field(description="The name of the graph")
|
||||||
description: str = Field(description="The description of the graph")
|
description: str = Field(description="The description of the graph")
|
||||||
|
@ -148,24 +148,20 @@ class ImageRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._lock = threading.Lock()
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
# Enable foreign keys
|
|
||||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Callable, Optional
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
@ -38,6 +38,27 @@ if TYPE_CHECKING:
|
|||||||
class ImageServiceABC(ABC):
|
class ImageServiceABC(ABC):
|
||||||
"""High-level service for image management."""
|
"""High-level service for image management."""
|
||||||
|
|
||||||
|
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
||||||
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||||
|
"""Register a callback for when an item is changed"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||||
|
"""Register a callback for when an item is deleted"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _on_changed(self, item: ImageDTO) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _on_deleted(self, item_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
@ -159,6 +180,24 @@ class ImageServiceDependencies:
|
|||||||
|
|
||||||
class ImageService(ImageServiceABC):
|
class ImageService(ImageServiceABC):
|
||||||
_services: ImageServiceDependencies
|
_services: ImageServiceDependencies
|
||||||
|
_on_changed_callbacks: list[Callable[[ImageDTO], None]] = list()
|
||||||
|
_on_deleted_callbacks: list[Callable[[str], None]] = list()
|
||||||
|
|
||||||
|
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||||
|
"""Register a callback for when an item is changed"""
|
||||||
|
self._on_changed_callbacks.append(on_changed)
|
||||||
|
|
||||||
|
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||||
|
"""Register a callback for when an item is deleted"""
|
||||||
|
self._on_deleted_callbacks.append(on_deleted)
|
||||||
|
|
||||||
|
def _on_changed(self, item: ImageDTO) -> None:
|
||||||
|
for callback in self._on_changed_callbacks:
|
||||||
|
callback(item)
|
||||||
|
|
||||||
|
def _on_deleted(self, item_id: str) -> None:
|
||||||
|
for callback in self._on_deleted_callbacks:
|
||||||
|
callback(item_id)
|
||||||
|
|
||||||
def __init__(self, services: ImageServiceDependencies):
|
def __init__(self, services: ImageServiceDependencies):
|
||||||
self._services = services
|
self._services = services
|
||||||
@ -217,6 +256,7 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
||||||
image_dto = self.get_dto(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
|
|
||||||
|
self._on_changed(image_dto)
|
||||||
return image_dto
|
return image_dto
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to save image record")
|
self._services.logger.error("Failed to save image record")
|
||||||
@ -235,7 +275,9 @@ class ImageService(ImageServiceABC):
|
|||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
self._services.image_records.update(image_name, changes)
|
self._services.image_records.update(image_name, changes)
|
||||||
return self.get_dto(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
|
self._on_changed(image_dto)
|
||||||
|
return image_dto
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to update image record")
|
self._services.logger.error("Failed to update image record")
|
||||||
raise
|
raise
|
||||||
@ -374,6 +416,7 @@ class ImageService(ImageServiceABC):
|
|||||||
try:
|
try:
|
||||||
self._services.image_files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
self._services.image_records.delete(image_name)
|
self._services.image_records.delete(image_name)
|
||||||
|
self._on_deleted(image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error("Failed to delete image record")
|
self._services.logger.error("Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
@ -390,6 +433,8 @@ class ImageService(ImageServiceABC):
|
|||||||
for image_name in image_names:
|
for image_name in image_names:
|
||||||
self._services.image_files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
self._services.image_records.delete_many(image_names)
|
self._services.image_records.delete_many(image_names)
|
||||||
|
for image_name in image_names:
|
||||||
|
self._on_deleted(image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error("Failed to delete image records")
|
self._services.logger.error("Failed to delete image records")
|
||||||
raise
|
raise
|
||||||
@ -406,6 +451,7 @@ class ImageService(ImageServiceABC):
|
|||||||
count = len(image_names)
|
count = len(image_names)
|
||||||
for image_name in image_names:
|
for image_name in image_names:
|
||||||
self._services.image_files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
|
self._on_deleted(image_name)
|
||||||
return count
|
return count
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error("Failed to delete image records")
|
self._services.logger.error("Failed to delete image records")
|
||||||
|
0
invokeai/app/services/invocation_cache/__init__.py
Normal file
0
invokeai/app/services/invocation_cache/__init__.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationCacheBase(ABC):
|
||||||
|
"""Base class for invocation caches."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||||
|
"""Retrieves and invocation output from the cache"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, value: BaseInvocationOutput) -> None:
|
||||||
|
"""Stores an invocation output in the cache"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, key: Union[int, str]) -> None:
|
||||||
|
"""Deleted an invocation output from the cache"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clears the cache"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
|
||||||
|
"""Creates the cache key for an invocation"""
|
||||||
|
pass
|
@ -0,0 +1,70 @@
|
|||||||
|
from queue import Queue
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
|
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryInvocationCache(InvocationCacheBase):
|
||||||
|
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
|
||||||
|
__max_cache_size: int
|
||||||
|
__cache_ids: Queue
|
||||||
|
__invoker: Invoker
|
||||||
|
|
||||||
|
def __init__(self, max_cache_size: int = 512) -> None:
|
||||||
|
self.__cache = dict()
|
||||||
|
self.__max_cache_size = max_cache_size
|
||||||
|
self.__cache_ids = Queue()
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
self.__invoker.services.images.on_deleted(self.delete_by_match)
|
||||||
|
|
||||||
|
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||||
|
if self.__max_cache_size == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
item = self.__cache.get(key, None)
|
||||||
|
if item is not None:
|
||||||
|
return item[0]
|
||||||
|
|
||||||
|
def save(self, value: BaseInvocationOutput) -> None:
|
||||||
|
if self.__max_cache_size == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
value_json = value.json(exclude={"id"})
|
||||||
|
key = hash(value_json)
|
||||||
|
|
||||||
|
if key not in self.__cache:
|
||||||
|
self.__cache[key] = (value, value_json)
|
||||||
|
self.__cache_ids.put(key)
|
||||||
|
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||||
|
try:
|
||||||
|
self.__cache.pop(self.__cache_ids.get())
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def delete(self, key: Union[int, str]) -> None:
|
||||||
|
if self.__max_cache_size == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if key in self.__cache:
|
||||||
|
del self.__cache[key]
|
||||||
|
|
||||||
|
def delete_by_match(self, to_match: str) -> None:
|
||||||
|
to_delete = []
|
||||||
|
for name, item in self.__cache.items():
|
||||||
|
if to_match in item[1]:
|
||||||
|
to_delete.append(name)
|
||||||
|
for key in to_delete:
|
||||||
|
self.delete(key)
|
||||||
|
|
||||||
|
def clear(self, *args, **kwargs) -> None:
|
||||||
|
self.__cache.clear()
|
||||||
|
self.__cache_ids = Queue()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
|
||||||
|
return hash(value.json(exclude={"id"}))
|
@ -11,6 +11,10 @@ from pydantic import BaseModel, Field
|
|||||||
class InvocationQueueItem(BaseModel):
|
class InvocationQueueItem(BaseModel):
|
||||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||||
|
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||||
|
session_queue_item_id: str = Field(
|
||||||
|
description="The ID of session queue item from which this invocation queue item came"
|
||||||
|
)
|
||||||
invoke_all: bool = Field(default=False)
|
invoke_all: bool = Field(default=False)
|
||||||
timestamp: float = Field(default_factory=time.time)
|
timestamp: float = Field(default_factory=time.time)
|
||||||
|
|
||||||
|
@ -12,12 +12,15 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
|
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
||||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
|
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||||
|
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
@ -28,8 +31,8 @@ class InvocationServices:
|
|||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
events: "EventServiceBase"
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||||
images: "ImageServiceABC"
|
images: "ImageServiceABC"
|
||||||
latents: "LatentsStorageBase"
|
latents: "LatentsStorageBase"
|
||||||
logger: "Logger"
|
logger: "Logger"
|
||||||
@ -37,6 +40,9 @@ class InvocationServices:
|
|||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
performance_statistics: "InvocationStatsServiceBase"
|
performance_statistics: "InvocationStatsServiceBase"
|
||||||
queue: "InvocationQueueABC"
|
queue: "InvocationQueueABC"
|
||||||
|
session_queue: "SessionQueueBase"
|
||||||
|
session_processor: "SessionProcessorBase"
|
||||||
|
invocation_cache: "InvocationCacheBase"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -44,8 +50,8 @@ class InvocationServices:
|
|||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||||
images: "ImageServiceABC",
|
images: "ImageServiceABC",
|
||||||
latents: "LatentsStorageBase",
|
latents: "LatentsStorageBase",
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
@ -53,10 +59,12 @@ class InvocationServices:
|
|||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
performance_statistics: "InvocationStatsServiceBase",
|
performance_statistics: "InvocationStatsServiceBase",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
|
session_queue: "SessionQueueBase",
|
||||||
|
session_processor: "SessionProcessorBase",
|
||||||
|
invocation_cache: "InvocationCacheBase",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
self.boards = boards
|
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
self.events = events
|
self.events = events
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
@ -68,3 +76,6 @@ class InvocationServices:
|
|||||||
self.processor = processor
|
self.processor = processor
|
||||||
self.performance_statistics = performance_statistics
|
self.performance_statistics = performance_statistics
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
|
self.session_queue = session_queue
|
||||||
|
self.session_processor = session_processor
|
||||||
|
self.invocation_cache = invocation_cache
|
||||||
|
@ -17,7 +17,9 @@ class Invoker:
|
|||||||
self.services = services
|
self.services = services
|
||||||
self._start()
|
self._start()
|
||||||
|
|
||||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
def invoke(
|
||||||
|
self, queue_id: str, queue_item_id: str, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||||
|
) -> Optional[str]:
|
||||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||||
|
|
||||||
@ -32,7 +34,8 @@ class Invoker:
|
|||||||
# Queue the invocation
|
# Queue the invocation
|
||||||
self.services.queue.put(
|
self.services.queue.put(
|
||||||
InvocationQueueItem(
|
InvocationQueueItem(
|
||||||
# session_id = session.id,
|
session_queue_item_id=queue_item_id,
|
||||||
|
session_queue_id=queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
invocation_id=invocation.id,
|
invocation_id=invocation.id,
|
||||||
invoke_all=invoke_all,
|
invoke_all=invoke_all,
|
||||||
|
@ -525,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
def _emit_load_event(
|
def _emit_load_event(
|
||||||
self,
|
self,
|
||||||
context,
|
context: InvocationContext,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
@ -537,6 +537,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
context.services.events.emit_model_load_completed(
|
context.services.events.emit_model_load_completed(
|
||||||
|
queue_id=context.queue_id,
|
||||||
|
queue_item_id=context.queue_item_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
@ -546,6 +548,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context.services.events.emit_model_load_started(
|
context.services.events.emit_model_load_started(
|
||||||
|
queue_id=context.queue_id,
|
||||||
|
queue_item_id=context.queue_item_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from threading import BoundedSemaphore, Event, Thread
|
from threading import BoundedSemaphore, Event, Thread
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
@ -37,10 +38,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
try:
|
try:
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||||
|
queue_item: Optional[InvocationQueueItem] = None
|
||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
queue_item = self.__invoker.services.queue.get()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
||||||
|
|
||||||
@ -48,7 +50,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# do not hammer the queue
|
# do not hammer the queue
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||||
queue_item.graph_execution_state_id
|
queue_item.graph_execution_state_id
|
||||||
@ -56,6 +57,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||||
self.__invoker.services.events.emit_session_retrieval_error(
|
self.__invoker.services.events.emit_session_retrieval_error(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
@ -67,6 +70,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||||
node_id=queue_item.invocation_id,
|
node_id=queue_item.invocation_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
@ -79,6 +84,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
|
|
||||||
# Send starting event
|
# Send starting event
|
||||||
self.__invoker.services.events.emit_invocation_started(
|
self.__invoker.services.events.emit_invocation_started(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
@ -89,13 +96,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
graph_id = graph_execution_state.id
|
graph_id = graph_execution_state.id
|
||||||
model_manager = self.__invoker.services.model_manager
|
model_manager = self.__invoker.services.model_manager
|
||||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||||
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||||
# this accomodates nodes which require a value, but get it only from a
|
# which handles a few things:
|
||||||
# connection
|
# - nodes that require a value, but get it only from a connection
|
||||||
|
# - referencing the invocation cache instead of executing the node
|
||||||
outputs = invocation.invoke_internal(
|
outputs = invocation.invoke_internal(
|
||||||
InvocationContext(
|
InvocationContext(
|
||||||
services=self.__invoker.services,
|
services=self.__invoker.services,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -111,6 +121,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
|
|
||||||
# Send complete event
|
# Send complete event
|
||||||
self.__invoker.services.events.emit_invocation_complete(
|
self.__invoker.services.events.emit_invocation_complete(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
@ -138,6 +150,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||||
# Send error event
|
# Send error event
|
||||||
self.__invoker.services.events.emit_invocation_error(
|
self.__invoker.services.events.emit_invocation_error(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
@ -155,10 +169,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
is_complete = graph_execution_state.is_complete()
|
is_complete = graph_execution_state.is_complete()
|
||||||
if queue_item.invoke_all and not is_complete:
|
if queue_item.invoke_all and not is_complete:
|
||||||
try:
|
try:
|
||||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
self.__invoker.invoke(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
|
graph_execution_state=graph_execution_state,
|
||||||
|
invoke_all=True,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||||
self.__invoker.services.events.emit_invocation_error(
|
self.__invoker.services.events.emit_invocation_error(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
@ -166,7 +187,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
self.__invoker.services.events.emit_graph_execution_complete(
|
||||||
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
|
queue_id=queue_item.session_queue_id,
|
||||||
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum, EnumMeta
|
from enum import Enum, EnumMeta
|
||||||
|
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
|
||||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||||
"""Enum for resource types."""
|
"""Enum for resource types."""
|
||||||
@ -25,6 +26,6 @@ class SimpleNameService(NameServiceBase):
|
|||||||
|
|
||||||
# TODO: Add customizable naming schemes
|
# TODO: Add customizable naming schemes
|
||||||
def create_image_name(self) -> str:
|
def create_image_name(self) -> str:
|
||||||
uuid_str = str(uuid.uuid4())
|
uuid_str = uuid_string()
|
||||||
filename = f"{uuid_str}.png"
|
filename = f"{uuid_str}.png"
|
||||||
return filename
|
return filename
|
||||||
|
0
invokeai/app/services/session_processor/__init__.py
Normal file
0
invokeai/app/services/session_processor/__init__.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
|
|
||||||
|
class SessionProcessorBase(ABC):
|
||||||
|
"""
|
||||||
|
Base class for session processor.
|
||||||
|
|
||||||
|
The session processor is responsible for executing sessions. It runs a simple polling loop,
|
||||||
|
checking the session queue for new sessions to execute. It must coordinate with the
|
||||||
|
invocation queue to ensure only one session is executing at a time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def resume(self) -> SessionProcessorStatus:
|
||||||
|
"""Starts or resumes the session processor"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pause(self) -> SessionProcessorStatus:
|
||||||
|
"""Pauses the session processor"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_status(self) -> SessionProcessorStatus:
|
||||||
|
"""Gets the status of the session processor"""
|
||||||
|
pass
|
@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SessionProcessorStatus(BaseModel):
|
||||||
|
is_started: bool = Field(description="Whether the session processor is started")
|
||||||
|
is_processing: bool = Field(description="Whether a session is being processed")
|
@ -0,0 +1,123 @@
|
|||||||
|
from threading import BoundedSemaphore
|
||||||
|
from threading import Event as ThreadEvent
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi_events.handlers.local import local_handler
|
||||||
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
|
||||||
|
from ..invoker import Invoker
|
||||||
|
from .session_processor_base import SessionProcessorBase
|
||||||
|
from .session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
|
POLLING_INTERVAL = 1
|
||||||
|
THREAD_LIMIT = 1
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker: Invoker = invoker
|
||||||
|
self.__queue_item: Optional[SessionQueueItem] = None
|
||||||
|
|
||||||
|
self.__resume_event = ThreadEvent()
|
||||||
|
self.__stop_event = ThreadEvent()
|
||||||
|
self.__poll_now_event = ThreadEvent()
|
||||||
|
|
||||||
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||||
|
|
||||||
|
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
|
||||||
|
self.__thread = Thread(
|
||||||
|
name="session_processor",
|
||||||
|
target=self.__process,
|
||||||
|
kwargs=dict(
|
||||||
|
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.__thread.start()
|
||||||
|
|
||||||
|
def stop(self, *args, **kwargs) -> None:
|
||||||
|
self.__stop_event.set()
|
||||||
|
|
||||||
|
def _poll_now(self) -> None:
|
||||||
|
self.__poll_now_event.set()
|
||||||
|
|
||||||
|
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||||
|
event_name = event[1]["event"]
|
||||||
|
|
||||||
|
match event_name:
|
||||||
|
case "graph_execution_state_complete" | "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
|
||||||
|
self.__queue_item = None
|
||||||
|
self._poll_now()
|
||||||
|
case "session_canceled" if self.__queue_item is not None and self.__queue_item.session_id == event[1][
|
||||||
|
"data"
|
||||||
|
]["graph_execution_state_id"]:
|
||||||
|
self.__queue_item = None
|
||||||
|
self._poll_now()
|
||||||
|
case "batch_enqueued":
|
||||||
|
self._poll_now()
|
||||||
|
case "queue_cleared":
|
||||||
|
self.__queue_item = None
|
||||||
|
self._poll_now()
|
||||||
|
|
||||||
|
def resume(self) -> SessionProcessorStatus:
|
||||||
|
if not self.__resume_event.is_set():
|
||||||
|
self.__resume_event.set()
|
||||||
|
return self.get_status()
|
||||||
|
|
||||||
|
def pause(self) -> SessionProcessorStatus:
|
||||||
|
if self.__resume_event.is_set():
|
||||||
|
self.__resume_event.clear()
|
||||||
|
return self.get_status()
|
||||||
|
|
||||||
|
def get_status(self) -> SessionProcessorStatus:
|
||||||
|
return SessionProcessorStatus(
|
||||||
|
is_started=self.__resume_event.is_set(),
|
||||||
|
is_processing=self.__queue_item is not None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __process(
|
||||||
|
self,
|
||||||
|
stop_event: ThreadEvent,
|
||||||
|
poll_now_event: ThreadEvent,
|
||||||
|
resume_event: ThreadEvent,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
stop_event.clear()
|
||||||
|
resume_event.set()
|
||||||
|
self.__threadLimit.acquire()
|
||||||
|
queue_item: Optional[SessionQueueItem] = None
|
||||||
|
self.__invoker.services.logger
|
||||||
|
while not stop_event.is_set():
|
||||||
|
poll_now_event.clear()
|
||||||
|
|
||||||
|
# do not dequeue if there is already a session running
|
||||||
|
if self.__queue_item is None and resume_event.is_set():
|
||||||
|
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||||
|
|
||||||
|
if queue_item is not None:
|
||||||
|
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||||
|
self.__queue_item = queue_item
|
||||||
|
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||||
|
self.__invoker.invoke(
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state=queue_item.session,
|
||||||
|
invoke_all=True,
|
||||||
|
)
|
||||||
|
queue_item = None
|
||||||
|
|
||||||
|
if queue_item is None:
|
||||||
|
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||||
|
poll_now_event.wait(POLLING_INTERVAL)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
stop_event.clear()
|
||||||
|
poll_now_event.clear()
|
||||||
|
self.__queue_item = None
|
||||||
|
self.__threadLimit.release()
|
0
invokeai/app/services/session_queue/__init__.py
Normal file
0
invokeai/app/services/session_queue/__init__.py
Normal file
112
invokeai/app/services/session_queue/session_queue_base.py
Normal file
112
invokeai/app/services/session_queue/session_queue_base.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
|
QUEUE_ITEM_STATUS,
|
||||||
|
Batch,
|
||||||
|
BatchStatus,
|
||||||
|
CancelByBatchIDsResult,
|
||||||
|
CancelByQueueIDResult,
|
||||||
|
ClearResult,
|
||||||
|
EnqueueBatchResult,
|
||||||
|
EnqueueGraphResult,
|
||||||
|
IsEmptyResult,
|
||||||
|
IsFullResult,
|
||||||
|
PruneResult,
|
||||||
|
SessionQueueItem,
|
||||||
|
SessionQueueItemDTO,
|
||||||
|
SessionQueueStatus,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueBase(ABC):
|
||||||
|
"""Base class for session queue"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||||
|
"""Dequeues the next session queue item."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
|
||||||
|
"""Enqueues a single graph for execution."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||||
|
"""Enqueues all permutations of a batch for execution."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||||
|
"""Gets the currently-executing session queue item"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||||
|
"""Gets the next session queue item (does not dequeue it)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear(self, queue_id: str) -> ClearResult:
|
||||||
|
"""Deletes all session queue items"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prune(self, queue_id: str) -> PruneResult:
|
||||||
|
"""Deletes all completed and errored session queue items"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||||
|
"""Checks if the queue is empty"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_full(self, queue_id: str) -> IsFullResult:
|
||||||
|
"""Checks if the queue is empty"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||||
|
"""Gets the status of the queue"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||||
|
"""Gets the status of a batch"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||||
|
"""Cancels a session queue item"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||||
|
"""Cancels all queue items with matching batch IDs"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||||
|
"""Cancels all queue items with matching queue ID"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_queue_items(
|
||||||
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
limit: int,
|
||||||
|
priority: int,
|
||||||
|
order_id: Optional[int] = None,
|
||||||
|
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||||
|
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||||
|
"""Gets a page of session queue items"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||||
|
"""Gets a session queue item by ID"""
|
||||||
|
pass
|
428
invokeai/app/services/session_queue/session_queue_common.py
Normal file
428
invokeai/app/services/session_queue/session_queue_common.py
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
from itertools import chain, product
|
||||||
|
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator, validator
|
||||||
|
from pydantic.json import pydantic_encoder
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
from invokeai.app.services.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
# region Errors
|
||||||
|
|
||||||
|
|
||||||
|
class BatchZippedLengthError(ValueError):
|
||||||
|
"""Raise when a batch has items of different lengths."""
|
||||||
|
|
||||||
|
|
||||||
|
class BatchItemsTypeError(TypeError):
|
||||||
|
"""Raise when a batch has items of different types."""
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDuplicateNodeFieldError(ValueError):
|
||||||
|
"""Raise when a batch has duplicate node_path and field_name."""
|
||||||
|
|
||||||
|
|
||||||
|
class TooManySessionsError(ValueError):
|
||||||
|
"""Raise when too many sessions are requested."""
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueItemNotFoundError(ValueError):
|
||||||
|
"""Raise when a queue item is not found."""
|
||||||
|
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
# region Batch
|
||||||
|
|
||||||
|
BatchDataType = Union[
|
||||||
|
StrictStr,
|
||||||
|
float,
|
||||||
|
int,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class NodeFieldValue(BaseModel):
|
||||||
|
node_path: str = Field(description="The node into which this batch data item will be substituted.")
|
||||||
|
field_name: str = Field(description="The field into which this batch data item will be substituted.")
|
||||||
|
value: BatchDataType = Field(description="The value to substitute into the node/field.")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDatum(BaseModel):
|
||||||
|
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||||
|
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||||
|
items: list[BatchDataType] = Field(
|
||||||
|
default_factory=list, description="The list of items to substitute into the node/field."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||||
|
|
||||||
|
|
||||||
|
class Batch(BaseModel):
|
||||||
|
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||||
|
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||||
|
graph: Graph = Field(description="The graph to initialize the session with")
|
||||||
|
runs: int = Field(
|
||||||
|
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_lengths(cls, v: Optional[BatchDataCollection]):
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
for batch_data_list in v:
|
||||||
|
first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
|
||||||
|
for i in batch_data_list:
|
||||||
|
if len(i.items) != first_item_length:
|
||||||
|
raise BatchZippedLengthError("Zipped batch items must all have the same length")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_types(cls, v: Optional[BatchDataCollection]):
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
for batch_data_list in v:
|
||||||
|
for datum in batch_data_list:
|
||||||
|
# Get the type of the first item in the list
|
||||||
|
first_item_type = type(datum.items[0]) if datum.items else None
|
||||||
|
for item in datum.items:
|
||||||
|
if type(item) is not first_item_type:
|
||||||
|
raise BatchItemsTypeError("All items in a batch must have the same type")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
paths: set[tuple[str, str]] = set()
|
||||||
|
for batch_data_list in v:
|
||||||
|
for datum in batch_data_list:
|
||||||
|
pair = (datum.node_path, datum.field_name)
|
||||||
|
if pair in paths:
|
||||||
|
raise BatchDuplicateNodeFieldError("Each batch data must have unique node_id and field_name")
|
||||||
|
paths.add(pair)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@root_validator(skip_on_failure=True)
|
||||||
|
def validate_batch_nodes_and_edges(cls, values):
|
||||||
|
batch_data_collection = cast(Optional[BatchDataCollection], values["data"])
|
||||||
|
if batch_data_collection is None:
|
||||||
|
return values
|
||||||
|
graph = cast(Graph, values["graph"])
|
||||||
|
for batch_data_list in batch_data_collection:
|
||||||
|
for batch_data in batch_data_list:
|
||||||
|
try:
|
||||||
|
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
|
||||||
|
except NodeNotFoundError:
|
||||||
|
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
||||||
|
if batch_data.field_name not in node.__fields__:
|
||||||
|
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
||||||
|
return values
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"required": [
|
||||||
|
"graph",
|
||||||
|
"runs",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# endregion Batch
|
||||||
|
|
||||||
|
|
||||||
|
# region Queue Items
|
||||||
|
|
||||||
|
DEFAULT_QUEUE_ID = "default"
|
||||||
|
|
||||||
|
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
||||||
|
field_values_raw = queue_item_dict.get("field_values", None)
|
||||||
|
return parse_raw_as(list[NodeFieldValue], field_values_raw) if field_values_raw is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||||
|
session_raw = queue_item_dict.get("session", "{}")
|
||||||
|
return parse_raw_as(GraphExecutionState, session_raw)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueItemWithoutGraph(BaseModel):
|
||||||
|
"""Session queue item without the full graph. Used for serialization."""
|
||||||
|
|
||||||
|
item_id: str = Field(description="The unique identifier of the session queue item")
|
||||||
|
order_id: int = Field(description="The auto-incrementing ID of the session queue item")
|
||||||
|
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||||
|
priority: int = Field(default=0, description="The priority of this queue item")
|
||||||
|
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||||
|
session_id: str = Field(
|
||||||
|
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||||
|
)
|
||||||
|
field_values: Optional[list[NodeFieldValue]] = Field(
|
||||||
|
default=None, description="The field values that were used for this queue item"
|
||||||
|
)
|
||||||
|
queue_id: str = Field(description="The id of the queue with which this item is associated")
|
||||||
|
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
|
||||||
|
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
|
||||||
|
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
|
||||||
|
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
||||||
|
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||||
|
# must parse these manually
|
||||||
|
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||||
|
return SessionQueueItemDTO(**queue_item_dict)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"required": [
|
||||||
|
"item_id",
|
||||||
|
"order_id",
|
||||||
|
"status",
|
||||||
|
"batch_id",
|
||||||
|
"queue_id",
|
||||||
|
"session_id",
|
||||||
|
"priority",
|
||||||
|
"session_id",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||||
|
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
|
||||||
|
# must parse these manually
|
||||||
|
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||||
|
queue_item_dict["session"] = get_session(queue_item_dict)
|
||||||
|
return SessionQueueItem(**queue_item_dict)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"required": [
|
||||||
|
"item_id",
|
||||||
|
"order_id",
|
||||||
|
"status",
|
||||||
|
"batch_id",
|
||||||
|
"queue_id",
|
||||||
|
"session_id",
|
||||||
|
"session",
|
||||||
|
"priority",
|
||||||
|
"session_id",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# endregion Queue Items
|
||||||
|
|
||||||
|
# region Query Results
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueStatus(BaseModel):
|
||||||
|
queue_id: str = Field(..., description="The ID of the queue")
|
||||||
|
item_id: Optional[str] = Field(description="The current queue item id")
|
||||||
|
batch_id: Optional[str] = Field(description="The current queue item's batch id")
|
||||||
|
session_id: Optional[str] = Field(description="The current queue item's session id")
|
||||||
|
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||||
|
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||||
|
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||||
|
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||||
|
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||||
|
total: int = Field(..., description="Total number of queue items")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchStatus(BaseModel):
|
||||||
|
queue_id: str = Field(..., description="The ID of the queue")
|
||||||
|
batch_id: str = Field(..., description="The ID of the batch")
|
||||||
|
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||||
|
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||||
|
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||||
|
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||||
|
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||||
|
total: int = Field(..., description="Total number of queue items")
|
||||||
|
|
||||||
|
|
||||||
|
class EnqueueBatchResult(BaseModel):
|
||||||
|
queue_id: str = Field(description="The ID of the queue")
|
||||||
|
enqueued: int = Field(description="The total number of queue items enqueued")
|
||||||
|
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||||
|
batch: Batch = Field(description="The batch that was enqueued")
|
||||||
|
priority: int = Field(description="The priority of the enqueued batch")
|
||||||
|
|
||||||
|
|
||||||
|
class EnqueueGraphResult(BaseModel):
|
||||||
|
enqueued: int = Field(description="The total number of queue items enqueued")
|
||||||
|
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||||
|
batch: Batch = Field(description="The batch that was enqueued")
|
||||||
|
priority: int = Field(description="The priority of the enqueued batch")
|
||||||
|
queue_item: SessionQueueItemDTO = Field(description="The queue item that was enqueued")
|
||||||
|
|
||||||
|
|
||||||
|
class ClearResult(BaseModel):
|
||||||
|
"""Result of clearing the session queue"""
|
||||||
|
|
||||||
|
deleted: int = Field(..., description="Number of queue items deleted")
|
||||||
|
|
||||||
|
|
||||||
|
class PruneResult(ClearResult):
|
||||||
|
"""Result of pruning the session queue"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CancelByBatchIDsResult(BaseModel):
|
||||||
|
"""Result of canceling by list of batch ids"""
|
||||||
|
|
||||||
|
canceled: int = Field(..., description="Number of queue items canceled")
|
||||||
|
|
||||||
|
|
||||||
|
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||||
|
"""Result of canceling by queue id"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class IsEmptyResult(BaseModel):
|
||||||
|
"""Result of checking if the session queue is empty"""
|
||||||
|
|
||||||
|
is_empty: bool = Field(..., description="Whether the session queue is empty")
|
||||||
|
|
||||||
|
|
||||||
|
class IsFullResult(BaseModel):
|
||||||
|
"""Result of checking if the session queue is full"""
|
||||||
|
|
||||||
|
is_full: bool = Field(..., description="Whether the session queue is full")
|
||||||
|
|
||||||
|
|
||||||
|
# endregion Query Results
|
||||||
|
|
||||||
|
|
||||||
|
# region Util
|
||||||
|
|
||||||
|
|
||||||
|
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
|
||||||
|
"""
|
||||||
|
Populates the given graph with the given batch data items.
|
||||||
|
"""
|
||||||
|
graph_clone = graph.copy(deep=True)
|
||||||
|
for item in node_field_values:
|
||||||
|
node = graph_clone.get_node(item.node_path)
|
||||||
|
if node is None:
|
||||||
|
continue
|
||||||
|
setattr(node, item.field_name, item.value)
|
||||||
|
graph_clone.update_node(item.node_path, node)
|
||||||
|
return graph_clone
|
||||||
|
|
||||||
|
|
||||||
|
def create_session_nfv_tuples(
|
||||||
|
batch: Batch, maximum: int
|
||||||
|
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
|
||||||
|
"""
|
||||||
|
Create all graph permutations from the given batch data and graph. Yields tuples
|
||||||
|
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
||||||
|
that was applied to the graph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Should this be a class method on Batch?
|
||||||
|
|
||||||
|
data: list[list[tuple[NodeFieldValue]]] = []
|
||||||
|
batch_data_collection = batch.data if batch.data is not None else []
|
||||||
|
for batch_datum_list in batch_data_collection:
|
||||||
|
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
|
||||||
|
|
||||||
|
node_field_values_to_zip: list[list[NodeFieldValue]] = []
|
||||||
|
for batch_datum in batch_datum_list:
|
||||||
|
node_field_values = [
|
||||||
|
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
|
||||||
|
for item in batch_datum.items
|
||||||
|
]
|
||||||
|
node_field_values_to_zip.append(node_field_values)
|
||||||
|
data.append(list(zip(*node_field_values_to_zip)))
|
||||||
|
|
||||||
|
# create generator to yield session,nfv tuples
|
||||||
|
count = 0
|
||||||
|
for _ in range(batch.runs):
|
||||||
|
for d in product(*data):
|
||||||
|
if count >= maximum:
|
||||||
|
return
|
||||||
|
flat_node_field_values = list(chain.from_iterable(d))
|
||||||
|
graph = populate_graph(batch.graph, flat_node_field_values)
|
||||||
|
yield (GraphExecutionState(graph=graph), flat_node_field_values)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
|
||||||
|
def calc_session_count(batch: Batch) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the number of sessions that would be created by the batch, without incurring
|
||||||
|
the overhead of actually generating them. Adapted from `create_sessions().
|
||||||
|
"""
|
||||||
|
# TODO: Should this be a class method on Batch?
|
||||||
|
if not batch.data:
|
||||||
|
return batch.runs
|
||||||
|
data = []
|
||||||
|
for batch_datum_list in batch.data:
|
||||||
|
to_zip = []
|
||||||
|
for batch_datum in batch_datum_list:
|
||||||
|
batch_data_items = range(len(batch_datum.items))
|
||||||
|
to_zip.append(batch_data_items)
|
||||||
|
data.append(list(zip(*to_zip)))
|
||||||
|
data_product = list(product(*data))
|
||||||
|
return len(data_product) * batch.runs
|
||||||
|
|
||||||
|
|
||||||
|
class SessionQueueValueToInsert(NamedTuple):
|
||||||
|
"""A tuple of values to insert into the session_queue table"""
|
||||||
|
|
||||||
|
item_id: str # item_id
|
||||||
|
queue_id: str # queue_id
|
||||||
|
session: str # session json
|
||||||
|
session_id: str # session_id
|
||||||
|
batch_id: str # batch_id
|
||||||
|
field_values: Optional[str] # field_values json
|
||||||
|
priority: int # priority
|
||||||
|
order_id: int # order_id
|
||||||
|
|
||||||
|
|
||||||
|
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_values_to_insert(
|
||||||
|
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int, order_id: int
|
||||||
|
) -> ValuesToInsert:
|
||||||
|
values_to_insert: ValuesToInsert = []
|
||||||
|
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||||
|
# sessions must have unique id
|
||||||
|
session.id = uuid_string()
|
||||||
|
values_to_insert.append(
|
||||||
|
SessionQueueValueToInsert(
|
||||||
|
uuid_string(), # item_id
|
||||||
|
queue_id, # queue_id
|
||||||
|
session.json(), # session (json)
|
||||||
|
session.id, # session_id
|
||||||
|
batch.batch_id, # batch_id
|
||||||
|
# must use pydantic_encoder bc field_values is a list of models
|
||||||
|
json.dumps(field_values, default=pydantic_encoder) if field_values else None, # field_values (json)
|
||||||
|
priority, # priority
|
||||||
|
order_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
order_id += 1
|
||||||
|
return values_to_insert
|
||||||
|
|
||||||
|
|
||||||
|
# endregion Util
|
829
invokeai/app/services/session_queue/session_queue_sqlite.py
Normal file
829
invokeai/app/services/session_queue/session_queue_sqlite.py
Normal file
@ -0,0 +1,829 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
|
from fastapi_events.handlers.local import local_handler
|
||||||
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
|
DEFAULT_QUEUE_ID,
|
||||||
|
QUEUE_ITEM_STATUS,
|
||||||
|
Batch,
|
||||||
|
BatchStatus,
|
||||||
|
CancelByBatchIDsResult,
|
||||||
|
CancelByQueueIDResult,
|
||||||
|
ClearResult,
|
||||||
|
EnqueueBatchResult,
|
||||||
|
EnqueueGraphResult,
|
||||||
|
IsEmptyResult,
|
||||||
|
IsFullResult,
|
||||||
|
PruneResult,
|
||||||
|
SessionQueueItem,
|
||||||
|
SessionQueueItemDTO,
|
||||||
|
SessionQueueItemNotFoundError,
|
||||||
|
SessionQueueStatus,
|
||||||
|
calc_session_count,
|
||||||
|
prepare_values_to_insert,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteSessionQueue(SessionQueueBase):
|
||||||
|
__invoker: Invoker
|
||||||
|
__conn: sqlite3.Connection
|
||||||
|
__cursor: sqlite3.Cursor
|
||||||
|
__lock: threading.Lock
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
self._set_in_progress_to_canceled()
|
||||||
|
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||||
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||||
|
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||||
|
|
||||||
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.__conn = conn
|
||||||
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self.__conn.row_factory = sqlite3.Row
|
||||||
|
self.__cursor = self.__conn.cursor()
|
||||||
|
self.__lock = lock
|
||||||
|
self._create_tables()
|
||||||
|
|
||||||
|
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||||
|
return event[1]["event"] in match_in
|
||||||
|
|
||||||
|
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||||
|
event_name = event[1]["event"]
|
||||||
|
match event_name:
|
||||||
|
case "graph_execution_state_complete":
|
||||||
|
await self._handle_complete_event(event)
|
||||||
|
case "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
|
||||||
|
await self._handle_error_event(event)
|
||||||
|
case "session_canceled":
|
||||||
|
await self._handle_cancel_event(event)
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||||
|
try:
|
||||||
|
item_id = event[1]["data"]["queue_item_id"]
|
||||||
|
# When a queue item has an error, we get an error event, then a completed event.
|
||||||
|
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||||
|
# by a previously-handled error event.
|
||||||
|
queue_item = self.get_queue_item(item_id)
|
||||||
|
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||||
|
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||||
|
except SessionQueueItemNotFoundError:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||||
|
try:
|
||||||
|
item_id = event[1]["data"]["queue_item_id"]
|
||||||
|
error = event[1]["data"]["error"]
|
||||||
|
queue_item = self.get_queue_item(item_id)
|
||||||
|
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||||
|
except SessionQueueItemNotFoundError:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||||
|
try:
|
||||||
|
item_id = event[1]["data"]["queue_item_id"]
|
||||||
|
queue_item = self.get_queue_item(item_id)
|
||||||
|
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||||
|
except SessionQueueItemNotFoundError:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
"""Creates the session queue tables, indicies, and triggers"""
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS session_queue (
|
||||||
|
item_id TEXT NOT NULL PRIMARY KEY, -- the unique identifier of this queue item
|
||||||
|
order_id INTEGER NOT NULL, -- used for ordering, cursor pagination
|
||||||
|
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||||
|
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||||
|
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||||
|
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||||
|
session TEXT NOT NULL, -- the session to be executed
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
|
||||||
|
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||||
|
error TEXT, -- any errors associated with this queue item
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||||
|
started_at DATETIME, -- updated via trigger
|
||||||
|
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
||||||
|
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||||
|
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_order_id ON session_queue(order_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||||
|
AFTER UPDATE OF status ON session_queue
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN
|
||||||
|
NEW.status = 'completed'
|
||||||
|
OR NEW.status = 'failed'
|
||||||
|
OR NEW.status = 'canceled'
|
||||||
|
BEGIN
|
||||||
|
UPDATE session_queue
|
||||||
|
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE item_id = NEW.item_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
||||||
|
AFTER UPDATE OF status ON session_queue
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN
|
||||||
|
NEW.status = 'in_progress'
|
||||||
|
BEGIN
|
||||||
|
UPDATE session_queue
|
||||||
|
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE item_id = NEW.item_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON session_queue FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE session_queue
|
||||||
|
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE item_id = old.item_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self.__conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
|
||||||
|
def _set_in_progress_to_canceled(self) -> None:
|
||||||
|
"""
|
||||||
|
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||||
|
This is necessary because the invoker may have been killed while processing a queue item.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE session_queue
|
||||||
|
SET status = 'canceled'
|
||||||
|
WHERE status = 'in_progress';
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
|
||||||
|
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||||
|
"""Gets the current number of pending queue items"""
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT count(*)
|
||||||
|
FROM session_queue
|
||||||
|
WHERE
|
||||||
|
queue_id = ?
|
||||||
|
AND status = 'pending'
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
return cast(int, self.__cursor.fetchone()[0])
|
||||||
|
|
||||||
|
def _get_highest_priority(self, queue_id: str) -> int:
|
||||||
|
"""Gets the highest priority value in the queue"""
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT MAX(priority)
|
||||||
|
FROM session_queue
|
||||||
|
WHERE
|
||||||
|
queue_id = ?
|
||||||
|
AND status = 'pending'
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
|
||||||
|
|
||||||
|
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
|
||||||
|
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM session_queue
|
||||||
|
WHERE queue_id = ?
|
||||||
|
AND batch_id = ?
|
||||||
|
""",
|
||||||
|
(queue_id, enqueue_result.batch.batch_id),
|
||||||
|
)
|
||||||
|
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
|
||||||
|
return EnqueueGraphResult(
|
||||||
|
**enqueue_result.dict(),
|
||||||
|
queue_item=SessionQueueItemDTO.from_dict(dict(result)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
|
||||||
|
# TODO: how does this work in a multi-user scenario?
|
||||||
|
current_queue_size = self._get_current_queue_size(queue_id)
|
||||||
|
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
|
||||||
|
max_new_queue_items = max_queue_size - current_queue_size
|
||||||
|
|
||||||
|
priority = 0
|
||||||
|
if prepend:
|
||||||
|
priority = self._get_highest_priority(queue_id) + 1
|
||||||
|
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT MAX(order_id)
|
||||||
|
FROM session_queue
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
max_order_id = cast(Optional[int], self.__cursor.fetchone()[0]) or 0
|
||||||
|
|
||||||
|
requested_count = calc_session_count(batch)
|
||||||
|
values_to_insert = prepare_values_to_insert(
|
||||||
|
queue_id=queue_id,
|
||||||
|
batch=batch,
|
||||||
|
priority=priority,
|
||||||
|
max_new_queue_items=max_new_queue_items,
|
||||||
|
order_id=max_order_id + 1,
|
||||||
|
)
|
||||||
|
enqueued_count = len(values_to_insert)
|
||||||
|
|
||||||
|
if requested_count > enqueued_count:
|
||||||
|
values_to_insert = values_to_insert[:max_new_queue_items]
|
||||||
|
|
||||||
|
self.__cursor.executemany(
|
||||||
|
"""--sql
|
||||||
|
INSERT INTO session_queue (item_id, queue_id, session, session_id, batch_id, field_values, priority, order_id)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
values_to_insert,
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
enqueue_result = EnqueueBatchResult(
|
||||||
|
queue_id=queue_id,
|
||||||
|
requested=requested_count,
|
||||||
|
enqueued=enqueued_count,
|
||||||
|
batch=batch,
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
|
||||||
|
return enqueue_result
|
||||||
|
|
||||||
|
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM session_queue
|
||||||
|
WHERE status = 'pending'
|
||||||
|
ORDER BY
|
||||||
|
priority DESC,
|
||||||
|
order_id ASC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
queue_item = SessionQueueItem.from_dict(dict(result))
|
||||||
|
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||||
|
return queue_item
|
||||||
|
|
||||||
|
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM session_queue
|
||||||
|
WHERE
|
||||||
|
queue_id = ?
|
||||||
|
AND status = 'pending'
|
||||||
|
ORDER BY
|
||||||
|
priority DESC,
|
||||||
|
created_at ASC
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
return SessionQueueItem.from_dict(dict(result))
|
||||||
|
|
||||||
|
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM session_queue
|
||||||
|
WHERE
|
||||||
|
queue_id = ?
|
||||||
|
AND status = 'in_progress'
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
return SessionQueueItem.from_dict(dict(result))
|
||||||
|
|
||||||
|
def _set_queue_item_status(
|
||||||
|
self, item_id: str, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
|
||||||
|
) -> SessionQueueItem:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE session_queue
|
||||||
|
SET status = ?, error = ?
|
||||||
|
WHERE item_id = ?
|
||||||
|
""",
|
||||||
|
(status, error, item_id),
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return self.get_queue_item(item_id)
|
||||||
|
|
||||||
|
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT count(*)
|
||||||
|
FROM session_queue
|
||||||
|
WHERE queue_id = ?
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return IsEmptyResult(is_empty=is_empty)
|
||||||
|
|
||||||
|
def is_full(self, queue_id: str) -> IsFullResult:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT count(*)
|
||||||
|
FROM session_queue
|
||||||
|
WHERE queue_id = ?
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||||
|
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return IsFullResult(is_full=is_full)
|
||||||
|
|
||||||
|
def delete_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||||
|
queue_item = self.get_queue_item(item_id=item_id)
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM session_queue
|
||||||
|
WHERE
|
||||||
|
item_id = ?
|
||||||
|
""",
|
||||||
|
(item_id,),
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return queue_item
|
||||||
|
|
||||||
|
def clear(self, queue_id: str) -> ClearResult:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM session_queue
|
||||||
|
WHERE queue_id = ?
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
count = self.__cursor.fetchone()[0]
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE
|
||||||
|
FROM session_queue
|
||||||
|
WHERE queue_id = ?
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||||
|
return ClearResult(deleted=count)
|
||||||
|
|
||||||
|
def prune(self, queue_id: str) -> PruneResult:
|
||||||
|
try:
|
||||||
|
where = """--sql
|
||||||
|
WHERE
|
||||||
|
queue_id = ?
|
||||||
|
AND (
|
||||||
|
status = 'completed'
|
||||||
|
OR status = 'failed'
|
||||||
|
OR status = 'canceled'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM session_queue
|
||||||
|
{where};
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
count = self.__cursor.fetchone()[0]
|
||||||
|
self.__cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
DELETE
|
||||||
|
FROM session_queue
|
||||||
|
{where};
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return PruneResult(deleted=count)
|
||||||
|
|
||||||
|
def cancel_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||||
|
queue_item = self.get_queue_item(item_id)
|
||||||
|
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||||
|
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||||
|
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||||
|
self.__invoker.services.events.emit_session_canceled(
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session_id,
|
||||||
|
)
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||||
|
return queue_item
|
||||||
|
|
||||||
|
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||||
|
try:
|
||||||
|
current_queue_item = self.get_current(queue_id)
|
||||||
|
self.__lock.acquire()
|
||||||
|
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||||
|
where = f"""--sql
|
||||||
|
WHERE
|
||||||
|
queue_id == ?
|
||||||
|
AND batch_id IN ({placeholders})
|
||||||
|
AND status != 'canceled'
|
||||||
|
AND status != 'completed'
|
||||||
|
AND status != 'failed'
|
||||||
|
"""
|
||||||
|
params = [queue_id] + batch_ids
|
||||||
|
self.__cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM session_queue
|
||||||
|
{where};
|
||||||
|
""",
|
||||||
|
tuple(params),
|
||||||
|
)
|
||||||
|
count = self.__cursor.fetchone()[0]
|
||||||
|
self.__cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE session_queue
|
||||||
|
SET status = 'canceled'
|
||||||
|
{where};
|
||||||
|
""",
|
||||||
|
tuple(params),
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||||
|
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||||
|
self.__invoker.services.events.emit_session_canceled(
|
||||||
|
queue_item_id=current_queue_item.item_id,
|
||||||
|
queue_id=current_queue_item.queue_id,
|
||||||
|
graph_execution_state_id=current_queue_item.session_id,
|
||||||
|
)
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return CancelByBatchIDsResult(canceled=count)
|
||||||
|
|
||||||
|
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||||
|
try:
|
||||||
|
current_queue_item = self.get_current(queue_id)
|
||||||
|
self.__lock.acquire()
|
||||||
|
where = """--sql
|
||||||
|
WHERE
|
||||||
|
queue_id is ?
|
||||||
|
AND status != 'canceled'
|
||||||
|
AND status != 'completed'
|
||||||
|
AND status != 'failed'
|
||||||
|
"""
|
||||||
|
params = [queue_id]
|
||||||
|
self.__cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM session_queue
|
||||||
|
{where};
|
||||||
|
""",
|
||||||
|
tuple(params),
|
||||||
|
)
|
||||||
|
count = self.__cursor.fetchone()[0]
|
||||||
|
self.__cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE session_queue
|
||||||
|
SET status = 'canceled'
|
||||||
|
{where};
|
||||||
|
""",
|
||||||
|
tuple(params),
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||||
|
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||||
|
self.__invoker.services.events.emit_session_canceled(
|
||||||
|
queue_item_id=current_queue_item.item_id,
|
||||||
|
queue_id=current_queue_item.queue_id,
|
||||||
|
graph_execution_state_id=current_queue_item.session_id,
|
||||||
|
)
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return CancelByQueueIDResult(canceled=count)
|
||||||
|
|
||||||
|
def get_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT * FROM session_queue
|
||||||
|
WHERE
|
||||||
|
item_id = ?
|
||||||
|
""",
|
||||||
|
(item_id,),
|
||||||
|
)
|
||||||
|
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||||
|
return SessionQueueItem.from_dict(dict(result))
|
||||||
|
|
||||||
|
def list_queue_items(
|
||||||
|
self,
|
||||||
|
queue_id: str,
|
||||||
|
limit: int,
|
||||||
|
priority: int,
|
||||||
|
order_id: Optional[int] = None,
|
||||||
|
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||||
|
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
query = """--sql
|
||||||
|
SELECT item_id,
|
||||||
|
order_id,
|
||||||
|
status,
|
||||||
|
priority,
|
||||||
|
field_values,
|
||||||
|
error,
|
||||||
|
created_at,
|
||||||
|
updated_at,
|
||||||
|
completed_at,
|
||||||
|
started_at,
|
||||||
|
session_id,
|
||||||
|
batch_id,
|
||||||
|
queue_id
|
||||||
|
FROM session_queue
|
||||||
|
WHERE queue_id = ?
|
||||||
|
"""
|
||||||
|
params: list[Union[str, int]] = [queue_id]
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
query += """--sql
|
||||||
|
AND status = ?
|
||||||
|
"""
|
||||||
|
params.append(status)
|
||||||
|
|
||||||
|
if order_id is not None:
|
||||||
|
query += """--sql
|
||||||
|
AND (priority < ?) OR (priority = ? AND order_id > ?)
|
||||||
|
"""
|
||||||
|
params.extend([priority, priority, order_id])
|
||||||
|
|
||||||
|
query += """--sql
|
||||||
|
ORDER BY
|
||||||
|
priority DESC,
|
||||||
|
order_id ASC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
params.append(limit + 1)
|
||||||
|
self.__cursor.execute(query, params)
|
||||||
|
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||||
|
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
|
||||||
|
has_more = False
|
||||||
|
if len(items) > limit:
|
||||||
|
# remove the extra item
|
||||||
|
items.pop()
|
||||||
|
has_more = True
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||||
|
|
||||||
|
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT status, count(*)
|
||||||
|
FROM session_queue
|
||||||
|
WHERE queue_id = ?
|
||||||
|
GROUP BY status
|
||||||
|
""",
|
||||||
|
(queue_id,),
|
||||||
|
)
|
||||||
|
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
|
||||||
|
current_item = self.get_current(queue_id=queue_id)
|
||||||
|
total = sum(row[1] for row in counts_result)
|
||||||
|
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||||
|
return SessionQueueStatus(
|
||||||
|
queue_id=queue_id,
|
||||||
|
item_id=current_item.item_id if current_item else None,
|
||||||
|
session_id=current_item.session_id if current_item else None,
|
||||||
|
batch_id=current_item.batch_id if current_item else None,
|
||||||
|
pending=counts.get("pending", 0),
|
||||||
|
in_progress=counts.get("in_progress", 0),
|
||||||
|
completed=counts.get("completed", 0),
|
||||||
|
failed=counts.get("failed", 0),
|
||||||
|
canceled=counts.get("canceled", 0),
|
||||||
|
total=total,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||||
|
try:
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT status, count(*)
|
||||||
|
FROM session_queue
|
||||||
|
WHERE
|
||||||
|
queue_id = ?
|
||||||
|
AND batch_id = ?
|
||||||
|
GROUP BY status
|
||||||
|
""",
|
||||||
|
(queue_id, batch_id),
|
||||||
|
)
|
||||||
|
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||||
|
total = sum(row[1] for row in result)
|
||||||
|
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
|
||||||
|
return BatchStatus(
|
||||||
|
batch_id=batch_id,
|
||||||
|
queue_id=queue_id,
|
||||||
|
pending=counts.get("pending", 0),
|
||||||
|
in_progress=counts.get("in_progress", 0),
|
||||||
|
completed=counts.get("completed", 0),
|
||||||
|
failed=counts.get("failed", 0),
|
||||||
|
canceled=counts.get("canceled", 0),
|
||||||
|
total=total,
|
||||||
|
)
|
14
invokeai/app/services/shared/models.py
Normal file
14
invokeai/app/services/shared/models.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.generics import GenericModel
|
||||||
|
|
||||||
|
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
||||||
|
"""Cursor-paginated results"""
|
||||||
|
|
||||||
|
limit: int = Field(..., description="Limit of items to get")
|
||||||
|
has_more: bool = Field(..., description="Whether there are more items available")
|
||||||
|
items: list[GenericBaseModel] = Field(..., description="Items")
|
@ -1,5 +1,5 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from threading import Lock
|
import threading
|
||||||
from typing import Generic, Optional, TypeVar, get_args
|
from typing import Generic, Optional, TypeVar, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel, parse_raw_as
|
from pydantic import BaseModel, parse_raw_as
|
||||||
@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
|
|||||||
|
|
||||||
|
|
||||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||||
_filename: str
|
|
||||||
_table_name: str
|
_table_name: str
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._filename = filename
|
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = Lock()
|
self._lock = lock
|
||||||
self._conn = sqlite3.connect(
|
self._conn = conn
|
||||||
self._filename, check_same_thread=False
|
|
||||||
) # TODO: figure out a better threading solution
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
parsed = parse_raw_as(item_type, item)
|
return parse_raw_as(item_type, item)
|
||||||
return parsed
|
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
|
3
invokeai/app/services/thread.py
Normal file
3
invokeai/app/services/thread.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
import threading
|
||||||
|
|
||||||
|
lock = threading.Lock()
|
@ -1,4 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -21,3 +22,8 @@ SEED_MAX = np.iinfo(np.uint32).max
|
|||||||
def get_random_seed():
|
def get_random_seed():
|
||||||
rng = np.random.default_rng(seed=None)
|
rng = np.random.default_rng(seed=None)
|
||||||
return int(rng.integers(0, SEED_MAX))
|
return int(rng.integers(0, SEED_MAX))
|
||||||
|
|
||||||
|
|
||||||
|
def uuid_string():
|
||||||
|
res = uuid.uuid4()
|
||||||
|
return str(res)
|
||||||
|
@ -110,6 +110,8 @@ def stable_diffusion_step_callback(
|
|||||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
context.services.events.emit_generator_progress(
|
context.services.events.emit_generator_progress(
|
||||||
|
queue_id=context.queue_id,
|
||||||
|
queue_item_id=context.queue_item_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
|
@ -14,7 +14,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
@ -27,6 +26,7 @@ from prompt_toolkit.key_binding import KeyBindings
|
|||||||
from prompt_toolkit.shortcuts import message_dialog
|
from prompt_toolkit.shortcuts import message_dialog
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
@ -421,7 +421,7 @@ VALUES ('{filename}', 'internal', 'general', {width}, {height}, null, null, '{me
|
|||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
else:
|
else:
|
||||||
board_date_string = datetime.datetime.utcnow().date().isoformat()
|
board_date_string = datetime.datetime.utcnow().date().isoformat()
|
||||||
new_board_id = str(uuid.uuid4())
|
new_board_id = uuid_string()
|
||||||
sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
|
sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
|
||||||
self.cursor.execute(sql_insert_board)
|
self.cursor.execute(sql_insert_board)
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
@ -13,14 +13,15 @@
|
|||||||
"reset": "Reset",
|
"reset": "Reset",
|
||||||
"rotateClockwise": "Rotate Clockwise",
|
"rotateClockwise": "Rotate Clockwise",
|
||||||
"rotateCounterClockwise": "Rotate Counter-Clockwise",
|
"rotateCounterClockwise": "Rotate Counter-Clockwise",
|
||||||
"showGallery": "Show Gallery",
|
"showGalleryPanel": "Show Gallery Panel",
|
||||||
"showOptionsPanel": "Show Side Panel",
|
"showOptionsPanel": "Show Side Panel",
|
||||||
"toggleAutoscroll": "Toggle autoscroll",
|
"toggleAutoscroll": "Toggle autoscroll",
|
||||||
"toggleLogViewer": "Toggle Log Viewer",
|
"toggleLogViewer": "Toggle Log Viewer",
|
||||||
"uploadImage": "Upload Image",
|
"uploadImage": "Upload Image",
|
||||||
"useThisParameter": "Use this parameter",
|
"useThisParameter": "Use this parameter",
|
||||||
"zoomIn": "Zoom In",
|
"zoomIn": "Zoom In",
|
||||||
"zoomOut": "Zoom Out"
|
"zoomOut": "Zoom Out",
|
||||||
|
"loadMore": "Load More"
|
||||||
},
|
},
|
||||||
"boards": {
|
"boards": {
|
||||||
"addBoard": "Add Board",
|
"addBoard": "Add Board",
|
||||||
@ -109,6 +110,7 @@
|
|||||||
"statusModelChanged": "Model Changed",
|
"statusModelChanged": "Model Changed",
|
||||||
"statusModelConverted": "Model Converted",
|
"statusModelConverted": "Model Converted",
|
||||||
"statusPreparing": "Preparing",
|
"statusPreparing": "Preparing",
|
||||||
|
"statusProcessing": "Processing",
|
||||||
"statusProcessingCanceled": "Processing Canceled",
|
"statusProcessingCanceled": "Processing Canceled",
|
||||||
"statusProcessingComplete": "Processing Complete",
|
"statusProcessingComplete": "Processing Complete",
|
||||||
"statusRestoringFaces": "Restoring Faces",
|
"statusRestoringFaces": "Restoring Faces",
|
||||||
@ -198,6 +200,63 @@
|
|||||||
"incompatibleModel": "Incompatible base model:",
|
"incompatibleModel": "Incompatible base model:",
|
||||||
"noMatchingEmbedding": "No matching Embeddings"
|
"noMatchingEmbedding": "No matching Embeddings"
|
||||||
},
|
},
|
||||||
|
"queue": {
|
||||||
|
"queue": "Queue",
|
||||||
|
"queueFront": "Add to Front of Queue",
|
||||||
|
"queueBack": "Add to Queue",
|
||||||
|
"queueCountPrediction": "Add {{predicted}} to Queue",
|
||||||
|
"queueMaxExceeded": "Max of {{max_queue_size}} exceeded, would skip {{skip}}",
|
||||||
|
"queuedCount": "{{pending}} Pending",
|
||||||
|
"queueTotal": "{{total}} Total",
|
||||||
|
"queueEmpty": "Queue Empty",
|
||||||
|
"enqueueing": "Queueing Batch",
|
||||||
|
"resume": "Resume",
|
||||||
|
"resumeTooltip": "Resume Processor",
|
||||||
|
"resumeSucceeded": "Processor Resumed",
|
||||||
|
"resumeFailed": "Problem Resuming Processor",
|
||||||
|
"pause": "Pause",
|
||||||
|
"pauseTooltip": "Pause Processor",
|
||||||
|
"pauseSucceeded": "Processor Paused",
|
||||||
|
"pauseFailed": "Problem Pausing Processor",
|
||||||
|
"cancel": "Cancel",
|
||||||
|
"cancelTooltip": "Cancel Current Item",
|
||||||
|
"cancelSucceeded": "Item Canceled",
|
||||||
|
"cancelFailed": "Problem Canceling Item",
|
||||||
|
"prune": "Prune",
|
||||||
|
"pruneTooltip": "Prune {{item_count}} Completed Items",
|
||||||
|
"pruneSucceeded": "Pruned {{item_count}} Completed Items from Queue",
|
||||||
|
"pruneFailed": "Problem Pruning Queue",
|
||||||
|
"clear": "Clear",
|
||||||
|
"clearTooltip": "Cancel and Clear All Items",
|
||||||
|
"clearSucceeded": "Queue Cleared",
|
||||||
|
"clearFailed": "Problem Clearing Queue",
|
||||||
|
"cancelBatch": "Cancel Batch",
|
||||||
|
"cancelItem": "Cancel Item",
|
||||||
|
"cancelBatchSucceeded": "Batch Canceled",
|
||||||
|
"cancelBatchFailed": "Problem Canceling Batch",
|
||||||
|
"current": "Current",
|
||||||
|
"next": "Next",
|
||||||
|
"status": "Status",
|
||||||
|
"total": "Total",
|
||||||
|
"pending": "Pending",
|
||||||
|
"in_progress": "In Progress",
|
||||||
|
"completed": "Completed",
|
||||||
|
"failed": "Failed",
|
||||||
|
"canceled": "Canceled",
|
||||||
|
"completedIn": "Completed in",
|
||||||
|
"batch": "Batch",
|
||||||
|
"item": "Item",
|
||||||
|
"session": "Session",
|
||||||
|
"batchValues": "Batch Values",
|
||||||
|
"notReady": "Unable to Queue",
|
||||||
|
"batchQueued": "Batch Queued",
|
||||||
|
"batchQueuedDesc": "Added {{item_count}} sessions to {{direction}} of queue",
|
||||||
|
"front": "front",
|
||||||
|
"back": "back",
|
||||||
|
"batchFailedToQueue": "Failed to Queue Batch",
|
||||||
|
"graphQueued": "Graph queued",
|
||||||
|
"graphFailedToQueue": "Failed to queue graph"
|
||||||
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"allImagesLoaded": "All Images Loaded",
|
"allImagesLoaded": "All Images Loaded",
|
||||||
"assets": "Assets",
|
"assets": "Assets",
|
||||||
@ -636,7 +695,8 @@
|
|||||||
"collectionItemDescription": "TODO",
|
"collectionItemDescription": "TODO",
|
||||||
"colorCodeEdges": "Color-Code Edges",
|
"colorCodeEdges": "Color-Code Edges",
|
||||||
"colorCodeEdgesHelp": "Color-code edges according to their connected fields",
|
"colorCodeEdgesHelp": "Color-code edges according to their connected fields",
|
||||||
"colorCollectionDescription": "A collection of colors.",
|
"colorCollection": "A collection of colors.",
|
||||||
|
"colorCollectionDescription": "TODO",
|
||||||
"colorField": "Color",
|
"colorField": "Color",
|
||||||
"colorFieldDescription": "A RGBA color.",
|
"colorFieldDescription": "A RGBA color.",
|
||||||
"colorPolymorphic": "Color Polymorphic",
|
"colorPolymorphic": "Color Polymorphic",
|
||||||
@ -683,7 +743,8 @@
|
|||||||
"imageFieldDescription": "Images may be passed between nodes.",
|
"imageFieldDescription": "Images may be passed between nodes.",
|
||||||
"imagePolymorphic": "Image Polymorphic",
|
"imagePolymorphic": "Image Polymorphic",
|
||||||
"imagePolymorphicDescription": "A collection of images.",
|
"imagePolymorphicDescription": "A collection of images.",
|
||||||
"inputFields": "Input Feilds",
|
"inputField": "Input Field",
|
||||||
|
"inputFields": "Input Fields",
|
||||||
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
|
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
|
||||||
"inputNode": "Input Node",
|
"inputNode": "Input Node",
|
||||||
"integer": "Integer",
|
"integer": "Integer",
|
||||||
@ -701,6 +762,7 @@
|
|||||||
"latentsPolymorphicDescription": "Latents may be passed between nodes.",
|
"latentsPolymorphicDescription": "Latents may be passed between nodes.",
|
||||||
"loadingNodes": "Loading Nodes...",
|
"loadingNodes": "Loading Nodes...",
|
||||||
"loadWorkflow": "Load Workflow",
|
"loadWorkflow": "Load Workflow",
|
||||||
|
"noWorkflow": "No Workflow",
|
||||||
"loRAModelField": "LoRA",
|
"loRAModelField": "LoRA",
|
||||||
"loRAModelFieldDescription": "TODO",
|
"loRAModelFieldDescription": "TODO",
|
||||||
"mainModelField": "Model",
|
"mainModelField": "Model",
|
||||||
@ -722,14 +784,15 @@
|
|||||||
"noImageFoundState": "No initial image found in state",
|
"noImageFoundState": "No initial image found in state",
|
||||||
"noMatchingNodes": "No matching nodes",
|
"noMatchingNodes": "No matching nodes",
|
||||||
"noNodeSelected": "No node selected",
|
"noNodeSelected": "No node selected",
|
||||||
"noOpacity": "Node Opacity",
|
"nodeOpacity": "Node Opacity",
|
||||||
"noOutputRecorded": "No outputs recorded",
|
"noOutputRecorded": "No outputs recorded",
|
||||||
"noOutputSchemaName": "No output schema name found in ref object",
|
"noOutputSchemaName": "No output schema name found in ref object",
|
||||||
"notes": "Notes",
|
"notes": "Notes",
|
||||||
"notesDescription": "Add notes about your workflow",
|
"notesDescription": "Add notes about your workflow",
|
||||||
"oNNXModelField": "ONNX Model",
|
"oNNXModelField": "ONNX Model",
|
||||||
"oNNXModelFieldDescription": "ONNX model field.",
|
"oNNXModelFieldDescription": "ONNX model field.",
|
||||||
"outputFields": "Output Feilds",
|
"outputField": "Output Field",
|
||||||
|
"outputFields": "Output Fields",
|
||||||
"outputNode": "Output node",
|
"outputNode": "Output node",
|
||||||
"outputSchemaNotFound": "Output schema not found",
|
"outputSchemaNotFound": "Output schema not found",
|
||||||
"pickOne": "Pick One",
|
"pickOne": "Pick One",
|
||||||
@ -778,6 +841,7 @@
|
|||||||
"unknownNode": "Unknown Node",
|
"unknownNode": "Unknown Node",
|
||||||
"unknownTemplate": "Unknown Template",
|
"unknownTemplate": "Unknown Template",
|
||||||
"unkownInvocation": "Unknown Invocation type",
|
"unkownInvocation": "Unknown Invocation type",
|
||||||
|
"updateNode": "Update Node",
|
||||||
"updateApp": "Update App",
|
"updateApp": "Update App",
|
||||||
"vaeField": "Vae",
|
"vaeField": "Vae",
|
||||||
"vaeFieldDescription": "Vae submodel.",
|
"vaeFieldDescription": "Vae submodel.",
|
||||||
@ -852,6 +916,7 @@
|
|||||||
"noInitialImageSelected": "No initial image selected",
|
"noInitialImageSelected": "No initial image selected",
|
||||||
"noModelForControlNet": "ControlNet {{index}} has no model selected.",
|
"noModelForControlNet": "ControlNet {{index}} has no model selected.",
|
||||||
"noModelSelected": "No model selected",
|
"noModelSelected": "No model selected",
|
||||||
|
"noPrompts": "No prompts generated",
|
||||||
"noNodesInGraph": "No nodes in graph",
|
"noNodesInGraph": "No nodes in graph",
|
||||||
"readyToInvoke": "Ready to Invoke",
|
"readyToInvoke": "Ready to Invoke",
|
||||||
"systemBusy": "System busy",
|
"systemBusy": "System busy",
|
||||||
@ -870,7 +935,12 @@
|
|||||||
"perlinNoise": "Perlin Noise",
|
"perlinNoise": "Perlin Noise",
|
||||||
"positivePromptPlaceholder": "Positive Prompt",
|
"positivePromptPlaceholder": "Positive Prompt",
|
||||||
"randomizeSeed": "Randomize Seed",
|
"randomizeSeed": "Randomize Seed",
|
||||||
|
"manualSeed": "Manual Seed",
|
||||||
|
"randomSeed": "Random Seed",
|
||||||
"restoreFaces": "Restore Faces",
|
"restoreFaces": "Restore Faces",
|
||||||
|
"iterations": "Iterations",
|
||||||
|
"iterationsWithCount_one": "{{count}} Iteration",
|
||||||
|
"iterationsWithCount_other": "{{count}} Iterations",
|
||||||
"scale": "Scale",
|
"scale": "Scale",
|
||||||
"scaleBeforeProcessing": "Scale Before Processing",
|
"scaleBeforeProcessing": "Scale Before Processing",
|
||||||
"scaledHeight": "Scaled H",
|
"scaledHeight": "Scaled H",
|
||||||
@ -884,10 +954,11 @@
|
|||||||
"seamLowThreshold": "Low",
|
"seamLowThreshold": "Low",
|
||||||
"seed": "Seed",
|
"seed": "Seed",
|
||||||
"seedWeights": "Seed Weights",
|
"seedWeights": "Seed Weights",
|
||||||
|
"imageActions": "Image Actions",
|
||||||
"sendTo": "Send to",
|
"sendTo": "Send to",
|
||||||
"sendToImg2Img": "Send to Image to Image",
|
"sendToImg2Img": "Send to Image to Image",
|
||||||
"sendToUnifiedCanvas": "Send To Unified Canvas",
|
"sendToUnifiedCanvas": "Send To Unified Canvas",
|
||||||
"showOptionsPanel": "Show Options Panel",
|
"showOptionsPanel": "Show Side Panel (O or T)",
|
||||||
"showPreview": "Show Preview",
|
"showPreview": "Show Preview",
|
||||||
"shuffle": "Shuffle Seed",
|
"shuffle": "Shuffle Seed",
|
||||||
"steps": "Steps",
|
"steps": "Steps",
|
||||||
@ -896,7 +967,7 @@
|
|||||||
"tileSize": "Tile Size",
|
"tileSize": "Tile Size",
|
||||||
"toggleLoopback": "Toggle Loopback",
|
"toggleLoopback": "Toggle Loopback",
|
||||||
"type": "Type",
|
"type": "Type",
|
||||||
"upscale": "Upscale",
|
"upscale": "Upscale (Shift + U)",
|
||||||
"upscaleImage": "Upscale Image",
|
"upscaleImage": "Upscale Image",
|
||||||
"upscaling": "Upscaling",
|
"upscaling": "Upscaling",
|
||||||
"useAll": "Use All",
|
"useAll": "Use All",
|
||||||
@ -909,11 +980,20 @@
|
|||||||
"vSymmetryStep": "V Symmetry Step",
|
"vSymmetryStep": "V Symmetry Step",
|
||||||
"width": "Width"
|
"width": "Width"
|
||||||
},
|
},
|
||||||
"prompt": {
|
"dynamicPrompts": {
|
||||||
"combinatorial": "Combinatorial Generation",
|
"combinatorial": "Combinatorial Generation",
|
||||||
"dynamicPrompts": "Dynamic Prompts",
|
"dynamicPrompts": "Dynamic Prompts",
|
||||||
"enableDynamicPrompts": "Enable Dynamic Prompts",
|
"enableDynamicPrompts": "Enable Dynamic Prompts",
|
||||||
"maxPrompts": "Max Prompts"
|
"maxPrompts": "Max Prompts",
|
||||||
|
"promptsWithCount_one": "{{count}} Prompt",
|
||||||
|
"promptsWithCount_other": "{{count}} Prompts",
|
||||||
|
"seedBehaviour": {
|
||||||
|
"label": "Seed Behaviour",
|
||||||
|
"perIterationLabel": "Seed per Iteration",
|
||||||
|
"perIterationDesc": "Use a different seed for each iteration",
|
||||||
|
"perPromptLabel": "Seed per Prompt",
|
||||||
|
"perPromptDesc": "Use a different seed for each prompt"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"sdxl": {
|
"sdxl": {
|
||||||
"cfgScale": "CFG Scale",
|
"cfgScale": "CFG Scale",
|
||||||
|
@ -1,44 +0,0 @@
|
|||||||
import { Flex, Spinner, Tooltip } from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
const selector = createSelector(systemSelector, (system) => {
|
|
||||||
const { isUploading } = system;
|
|
||||||
|
|
||||||
let tooltip = '';
|
|
||||||
|
|
||||||
if (isUploading) {
|
|
||||||
tooltip = 'Uploading...';
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
tooltip,
|
|
||||||
shouldShow: isUploading,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
export const AuxiliaryProgressIndicator = () => {
|
|
||||||
const { shouldShow, tooltip } = useAppSelector(selector);
|
|
||||||
|
|
||||||
if (!shouldShow) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
color: 'base.600',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Tooltip label={tooltip} placement="right" hasArrow>
|
|
||||||
<Spinner />
|
|
||||||
</Tooltip>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(AuxiliaryProgressIndicator);
|
|
@ -1,6 +1,8 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useQueueBack } from 'features/queue/hooks/useQueueBack';
|
||||||
|
import { useQueueFront } from 'features/queue/hooks/useQueueFront';
|
||||||
import {
|
import {
|
||||||
ctrlKeyPressed,
|
ctrlKeyPressed,
|
||||||
metaKeyPressed,
|
metaKeyPressed,
|
||||||
@ -33,6 +35,39 @@ const globalHotkeysSelector = createSelector(
|
|||||||
const GlobalHotkeys: React.FC = () => {
|
const GlobalHotkeys: React.FC = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { shift, ctrl, meta } = useAppSelector(globalHotkeysSelector);
|
const { shift, ctrl, meta } = useAppSelector(globalHotkeysSelector);
|
||||||
|
const {
|
||||||
|
queueBack,
|
||||||
|
isDisabled: isDisabledQueueBack,
|
||||||
|
isLoading: isLoadingQueueBack,
|
||||||
|
} = useQueueBack();
|
||||||
|
|
||||||
|
useHotkeys(
|
||||||
|
['ctrl+enter', 'meta+enter'],
|
||||||
|
queueBack,
|
||||||
|
{
|
||||||
|
enabled: () => !isDisabledQueueBack && !isLoadingQueueBack,
|
||||||
|
preventDefault: true,
|
||||||
|
enableOnFormTags: ['input', 'textarea', 'select'],
|
||||||
|
},
|
||||||
|
[queueBack, isDisabledQueueBack, isLoadingQueueBack]
|
||||||
|
);
|
||||||
|
|
||||||
|
const {
|
||||||
|
queueFront,
|
||||||
|
isDisabled: isDisabledQueueFront,
|
||||||
|
isLoading: isLoadingQueueFront,
|
||||||
|
} = useQueueFront();
|
||||||
|
|
||||||
|
useHotkeys(
|
||||||
|
['ctrl+shift+enter', 'meta+shift+enter'],
|
||||||
|
queueFront,
|
||||||
|
{
|
||||||
|
enabled: () => !isDisabledQueueFront && !isLoadingQueueFront,
|
||||||
|
preventDefault: true,
|
||||||
|
enableOnFormTags: ['input', 'textarea', 'select'],
|
||||||
|
},
|
||||||
|
[queueFront, isDisabledQueueFront, isLoadingQueueFront]
|
||||||
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'*',
|
'*',
|
||||||
|
@ -17,6 +17,7 @@ import '../../i18n';
|
|||||||
import AppDndContext from '../../features/dnd/components/AppDndContext';
|
import AppDndContext from '../../features/dnd/components/AppDndContext';
|
||||||
import { $customStarUI, CustomStarUi } from 'app/store/nanostores/customStarUI';
|
import { $customStarUI, CustomStarUi } from 'app/store/nanostores/customStarUI';
|
||||||
import { $headerComponent } from 'app/store/nanostores/headerComponent';
|
import { $headerComponent } from 'app/store/nanostores/headerComponent';
|
||||||
|
import { $queueId, DEFAULT_QUEUE_ID } from 'features/queue/store/nanoStores';
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||||
@ -28,6 +29,7 @@ interface Props extends PropsWithChildren {
|
|||||||
headerComponent?: ReactNode;
|
headerComponent?: ReactNode;
|
||||||
middleware?: Middleware[];
|
middleware?: Middleware[];
|
||||||
projectId?: string;
|
projectId?: string;
|
||||||
|
queueId?: string;
|
||||||
selectedImage?: {
|
selectedImage?: {
|
||||||
imageName: string;
|
imageName: string;
|
||||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||||
@ -42,6 +44,7 @@ const InvokeAIUI = ({
|
|||||||
headerComponent,
|
headerComponent,
|
||||||
middleware,
|
middleware,
|
||||||
projectId,
|
projectId,
|
||||||
|
queueId,
|
||||||
selectedImage,
|
selectedImage,
|
||||||
customStarUi,
|
customStarUi,
|
||||||
}: Props) => {
|
}: Props) => {
|
||||||
@ -61,6 +64,11 @@ const InvokeAIUI = ({
|
|||||||
$projectId.set(projectId);
|
$projectId.set(projectId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// configure API client project header
|
||||||
|
if (queueId) {
|
||||||
|
$queueId.set(queueId);
|
||||||
|
}
|
||||||
|
|
||||||
// reset dynamically added middlewares
|
// reset dynamically added middlewares
|
||||||
resetMiddlewares();
|
resetMiddlewares();
|
||||||
|
|
||||||
@ -81,8 +89,9 @@ const InvokeAIUI = ({
|
|||||||
$baseUrl.set(undefined);
|
$baseUrl.set(undefined);
|
||||||
$authToken.set(undefined);
|
$authToken.set(undefined);
|
||||||
$projectId.set(undefined);
|
$projectId.set(undefined);
|
||||||
|
$queueId.set(DEFAULT_QUEUE_ID);
|
||||||
};
|
};
|
||||||
}, [apiUrl, token, middleware, projectId]);
|
}, [apiUrl, token, middleware, projectId, queueId]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (customStarUi) {
|
if (customStarUi) {
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { useToast } from '@chakra-ui/react';
|
import { useToast } from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { toastQueueSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
||||||
import { MakeToastArg, makeToast } from 'features/system/util/makeToast';
|
import { MakeToastArg, makeToast } from 'features/system/util/makeToast';
|
||||||
import { memo, useCallback, useEffect } from 'react';
|
import { memo, useCallback, useEffect } from 'react';
|
||||||
@ -11,7 +10,7 @@ import { memo, useCallback, useEffect } from 'react';
|
|||||||
*/
|
*/
|
||||||
const Toaster = () => {
|
const Toaster = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const toastQueue = useAppSelector(toastQueueSelector);
|
const toastQueue = useAppSelector((state) => state.system.toastQueue);
|
||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
toastQueue.forEach((t) => {
|
toastQueue.forEach((t) => {
|
||||||
|
@ -20,6 +20,7 @@ export type LoggerNamespace =
|
|||||||
| 'system'
|
| 'system'
|
||||||
| 'socketio'
|
| 'socketio'
|
||||||
| 'session'
|
| 'session'
|
||||||
|
| 'queue'
|
||||||
| 'dnd';
|
| 'dnd';
|
||||||
|
|
||||||
export const logger = (namespace: LoggerNamespace) =>
|
export const logger = (namespace: LoggerNamespace) =>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { createLogWriter } from '@roarr/browser-log-writer';
|
import { createLogWriter } from '@roarr/browser-log-writer';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { useEffect, useMemo } from 'react';
|
import { useEffect, useMemo } from 'react';
|
||||||
import { ROARR, Roarr } from 'roarr';
|
import { ROARR, Roarr } from 'roarr';
|
||||||
@ -14,8 +14,8 @@ import {
|
|||||||
} from './logger';
|
} from './logger';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
systemSelector,
|
stateSelector,
|
||||||
(system) => {
|
({ system }) => {
|
||||||
const { consoleLogLevel, shouldLogToConsole } = system;
|
const { consoleLogLevel, shouldLogToConsole } = system;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
|
import { BatchConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
|
export const enqueueRequested = createAction<{
|
||||||
|
tabName: InvokeTabName;
|
||||||
|
prepend: boolean;
|
||||||
|
}>('app/enqueueRequested');
|
||||||
|
|
||||||
|
export const batchEnqueued = createAction<BatchConfig>('app/batchEnqueued');
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||||
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
|
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
|
||||||
|
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
|
||||||
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||||
@ -20,6 +21,7 @@ const serializationDenylist: {
|
|||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
ui: uiPersistDenylist,
|
ui: uiPersistDenylist,
|
||||||
controlNet: controlNetDenylist,
|
controlNet: controlNetDenylist,
|
||||||
|
dynamicPrompts: dynamicPromptsPersistDenylist,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const serialize: SerializeFunction = (data, key) => {
|
export const serialize: SerializeFunction = (data, key) => {
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||||
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
|
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { initialDynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||||
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||||
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||||
|
import { initialSDXLState } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { initialConfigState } from 'features/system/store/configSlice';
|
import { initialConfigState } from 'features/system/store/configSlice';
|
||||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||||
@ -24,6 +26,8 @@ const initialStates: {
|
|||||||
ui: initialUIState,
|
ui: initialUIState,
|
||||||
hotkeys: initialHotkeysState,
|
hotkeys: initialHotkeysState,
|
||||||
controlNet: initialControlNetState,
|
controlNet: initialControlNetState,
|
||||||
|
dynamicPrompts: initialDynamicPromptsState,
|
||||||
|
sdxl: initialSDXLState,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const unserialize: UnserializeFunction = (data, key) => {
|
export const unserialize: UnserializeFunction = (data, key) => {
|
||||||
|
@ -9,6 +9,7 @@ import {
|
|||||||
import type { AppDispatch, RootState } from '../../store';
|
import type { AppDispatch, RootState } from '../../store';
|
||||||
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
||||||
import { addFirstListImagesListener } from './listeners/addFirstListImagesListener.ts';
|
import { addFirstListImagesListener } from './listeners/addFirstListImagesListener.ts';
|
||||||
|
import { addAnyEnqueuedListener } from './listeners/anyEnqueued';
|
||||||
import { addAppConfigReceivedListener } from './listeners/appConfigReceived';
|
import { addAppConfigReceivedListener } from './listeners/appConfigReceived';
|
||||||
import { addAppStartedListener } from './listeners/appStarted';
|
import { addAppStartedListener } from './listeners/appStarted';
|
||||||
import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndImagesDeleted';
|
import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndImagesDeleted';
|
||||||
@ -22,6 +23,9 @@ import { addCanvasMergedListener } from './listeners/canvasMerged';
|
|||||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||||
|
import { addEnqueueRequestedCanvasListener } from './listeners/enqueueRequestedCanvas';
|
||||||
|
import { addEnqueueRequestedLinear } from './listeners/enqueueRequestedLinear';
|
||||||
|
import { addEnqueueRequestedNodes } from './listeners/enqueueRequestedNodes';
|
||||||
import {
|
import {
|
||||||
addImageAddedToBoardFulfilledListener,
|
addImageAddedToBoardFulfilledListener,
|
||||||
addImageAddedToBoardRejectedListener,
|
addImageAddedToBoardRejectedListener,
|
||||||
@ -48,6 +52,7 @@ import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
|
|||||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||||
import { addModelSelectedListener } from './listeners/modelSelected';
|
import { addModelSelectedListener } from './listeners/modelSelected';
|
||||||
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
||||||
|
import { addDynamicPromptsListener } from './listeners/promptChanged';
|
||||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||||
import {
|
import {
|
||||||
addSessionCanceledFulfilledListener,
|
addSessionCanceledFulfilledListener,
|
||||||
@ -64,7 +69,6 @@ import {
|
|||||||
addSessionInvokedPendingListener,
|
addSessionInvokedPendingListener,
|
||||||
addSessionInvokedRejectedListener,
|
addSessionInvokedRejectedListener,
|
||||||
} from './listeners/sessionInvoked';
|
} from './listeners/sessionInvoked';
|
||||||
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
|
|
||||||
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||||
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||||
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||||
@ -74,16 +78,13 @@ import { addInvocationErrorEventListener as addInvocationErrorListener } from '.
|
|||||||
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
||||||
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||||
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
|
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
|
||||||
|
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
|
||||||
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
||||||
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||||
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||||
import { addTabChangedListener } from './listeners/tabChanged';
|
import { addTabChangedListener } from './listeners/tabChanged';
|
||||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
|
||||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
|
||||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
@ -131,11 +132,10 @@ addImagesStarredListener();
|
|||||||
addImagesUnstarredListener();
|
addImagesUnstarredListener();
|
||||||
|
|
||||||
// User Invoked
|
// User Invoked
|
||||||
addUserInvokedCanvasListener();
|
addEnqueueRequestedCanvasListener();
|
||||||
addUserInvokedNodesListener();
|
addEnqueueRequestedNodes();
|
||||||
addUserInvokedTextToImageListener();
|
addEnqueueRequestedLinear();
|
||||||
addUserInvokedImageToImageListener();
|
addAnyEnqueuedListener();
|
||||||
addSessionReadyToInvokeListener();
|
|
||||||
|
|
||||||
// Canvas actions
|
// Canvas actions
|
||||||
addCanvasSavedToGalleryListener();
|
addCanvasSavedToGalleryListener();
|
||||||
@ -173,6 +173,7 @@ addSocketUnsubscribedListener();
|
|||||||
addModelLoadEventListener();
|
addModelLoadEventListener();
|
||||||
addSessionRetrievalErrorEventListener();
|
addSessionRetrievalErrorEventListener();
|
||||||
addInvocationRetrievalErrorEventListener();
|
addInvocationRetrievalErrorEventListener();
|
||||||
|
addSocketQueueItemStatusChangedEventListener();
|
||||||
|
|
||||||
// Session Created
|
// Session Created
|
||||||
addSessionCreatedPendingListener();
|
addSessionCreatedPendingListener();
|
||||||
@ -223,3 +224,6 @@ addUpscaleRequestedListener();
|
|||||||
|
|
||||||
// Tab Change
|
// Tab Change
|
||||||
addTabChangedListener();
|
addTabChangedListener();
|
||||||
|
|
||||||
|
// Dynamic prompts
|
||||||
|
addDynamicPromptsListener();
|
||||||
|
@ -1,39 +1,53 @@
|
|||||||
|
import { isAnyOf } from '@reduxjs/toolkit';
|
||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
|
import {
|
||||||
import { sessionCanceled } from 'services/api/thunks/session';
|
canvasBatchesAndSessionsReset,
|
||||||
|
commitStagingAreaImage,
|
||||||
|
discardStagedImages,
|
||||||
|
} from 'features/canvas/store/canvasSlice';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages);
|
||||||
|
|
||||||
export const addCommitStagingAreaImageListener = () => {
|
export const addCommitStagingAreaImageListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: commitStagingAreaImage,
|
matcher,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (_, { dispatch, getState }) => {
|
||||||
const log = logger('canvas');
|
const log = logger('canvas');
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const { sessionId: session_id, isProcessing } = state.system;
|
const { batchIds } = state.canvas;
|
||||||
const canvasSessionId = action.payload;
|
|
||||||
|
|
||||||
if (!isProcessing) {
|
try {
|
||||||
// Only need to cancel if we are processing
|
const req = dispatch(
|
||||||
return;
|
queueApi.endpoints.cancelByBatchIds.initiate(
|
||||||
}
|
{ batch_ids: batchIds },
|
||||||
|
{ fixedCacheKey: 'cancelByBatchIds' }
|
||||||
if (!canvasSessionId) {
|
)
|
||||||
log.debug('No canvas session, skipping cancel');
|
);
|
||||||
return;
|
const { canceled } = await req.unwrap();
|
||||||
}
|
req.reset();
|
||||||
|
if (canceled > 0) {
|
||||||
if (canvasSessionId !== session_id) {
|
log.debug(`Canceled ${canceled} canvas batches`);
|
||||||
log.debug(
|
dispatch(
|
||||||
{
|
addToast({
|
||||||
canvasSessionId,
|
title: t('queue.cancelBatchSucceeded'),
|
||||||
session_id,
|
status: 'success',
|
||||||
},
|
})
|
||||||
'Canvas session does not match global session, skipping cancel'
|
);
|
||||||
|
}
|
||||||
|
dispatch(canvasBatchesAndSessionsReset());
|
||||||
|
} catch {
|
||||||
|
log.error('Failed to cancel canvas batches');
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.cancelBatchFailed'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
);
|
);
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(sessionCanceled({ session_id }));
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
import { isAnyOf } from '@reduxjs/toolkit';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
const matcher = isAnyOf(
|
||||||
|
queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||||
|
queueApi.endpoints.enqueueGraph.matchFulfilled
|
||||||
|
);
|
||||||
|
|
||||||
|
export const addAnyEnqueuedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher,
|
||||||
|
effect: async (_, { dispatch, getState }) => {
|
||||||
|
const { data } = queueApi.endpoints.getQueueStatus.select()(getState());
|
||||||
|
|
||||||
|
if (!data || data.processor.is_started) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||||
|
fixedCacheKey: 'resumeProcessor',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -52,11 +52,9 @@ const predicate: AnyListenerPredicate<RootState> = (
|
|||||||
|
|
||||||
const isProcessorSelected = processorType !== 'none';
|
const isProcessorSelected = processorType !== 'none';
|
||||||
|
|
||||||
const isBusy = state.system.isProcessing;
|
|
||||||
|
|
||||||
const hasControlImage = Boolean(controlImage);
|
const hasControlImage = Boolean(controlImage);
|
||||||
|
|
||||||
return isProcessorSelected && !isBusy && hasControlImage;
|
return isProcessorSelected && hasControlImage;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
|
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
|
||||||
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
import { isImageOutput } from 'services/api/guards';
|
import { isImageOutput } from 'services/api/guards';
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
|
||||||
import { Graph, ImageDTO } from 'services/api/types';
|
import { Graph, ImageDTO } from 'services/api/types';
|
||||||
import { socketInvocationComplete } from 'services/events/actions';
|
import { socketInvocationComplete } from 'services/events/actions';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
@ -31,51 +34,84 @@ export const addControlNetImageProcessedListener = () => {
|
|||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
image: { image_name: controlNet.controlImage },
|
image: { image_name: controlNet.controlImage },
|
||||||
},
|
},
|
||||||
|
[SAVE_IMAGE]: {
|
||||||
|
id: SAVE_IMAGE,
|
||||||
|
type: 'save_image',
|
||||||
|
is_intermediate: true,
|
||||||
|
use_cache: false,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: controlNet.processorNode.id,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SAVE_IMAGE,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
};
|
};
|
||||||
|
try {
|
||||||
// Create a session to run the graph & wait til it's ready to invoke
|
const req = dispatch(
|
||||||
const sessionCreatedAction = dispatch(sessionCreated({ graph }));
|
queueApi.endpoints.enqueueGraph.initiate(
|
||||||
const [sessionCreatedFulfilledAction] = await take(
|
{ graph, prepend: true },
|
||||||
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
|
{
|
||||||
sessionCreated.fulfilled.match(action) &&
|
fixedCacheKey: 'enqueueGraph',
|
||||||
action.meta.requestId === sessionCreatedAction.requestId
|
}
|
||||||
);
|
)
|
||||||
|
|
||||||
const sessionId = sessionCreatedFulfilledAction.payload.id;
|
|
||||||
|
|
||||||
// Invoke the session & wait til it's complete
|
|
||||||
dispatch(sessionReadyToInvoke());
|
|
||||||
const [invocationCompleteAction] = await take(
|
|
||||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
|
||||||
socketInvocationComplete.match(action) &&
|
|
||||||
action.payload.data.graph_execution_state_id === sessionId
|
|
||||||
);
|
|
||||||
|
|
||||||
// We still have to check the output type
|
|
||||||
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
|
|
||||||
const { image_name } =
|
|
||||||
invocationCompleteAction.payload.data.result.image;
|
|
||||||
|
|
||||||
// Wait for the ImageDTO to be received
|
|
||||||
const [{ payload }] = await take(
|
|
||||||
(action) =>
|
|
||||||
imagesApi.endpoints.getImageDTO.matchFulfilled(action) &&
|
|
||||||
action.payload.image_name === image_name
|
|
||||||
);
|
);
|
||||||
|
const enqueueResult = await req.unwrap();
|
||||||
const processedControlImage = payload as ImageDTO;
|
req.reset();
|
||||||
|
console.log(enqueueResult.queue_item.session_id);
|
||||||
log.debug(
|
log.debug(
|
||||||
{ controlNetId: action.payload, processedControlImage },
|
{ enqueueResult: parseify(enqueueResult) },
|
||||||
'ControlNet image processed'
|
t('queue.graphQueued')
|
||||||
);
|
);
|
||||||
|
|
||||||
// Update the processed image in the store
|
const [invocationCompleteAction] = await take(
|
||||||
|
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||||
|
socketInvocationComplete.match(action) &&
|
||||||
|
action.payload.data.graph_execution_state_id ===
|
||||||
|
enqueueResult.queue_item.session_id &&
|
||||||
|
action.payload.data.source_node_id === SAVE_IMAGE
|
||||||
|
);
|
||||||
|
|
||||||
|
// We still have to check the output type
|
||||||
|
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
|
||||||
|
const { image_name } =
|
||||||
|
invocationCompleteAction.payload.data.result.image;
|
||||||
|
|
||||||
|
// Wait for the ImageDTO to be received
|
||||||
|
const [{ payload }] = await take(
|
||||||
|
(action) =>
|
||||||
|
imagesApi.endpoints.getImageDTO.matchFulfilled(action) &&
|
||||||
|
action.payload.image_name === image_name
|
||||||
|
);
|
||||||
|
|
||||||
|
const processedControlImage = payload as ImageDTO;
|
||||||
|
|
||||||
|
log.debug(
|
||||||
|
{ controlNetId: action.payload, processedControlImage },
|
||||||
|
'ControlNet image processed'
|
||||||
|
);
|
||||||
|
|
||||||
|
// Update the processed image in the store
|
||||||
|
dispatch(
|
||||||
|
controlNetProcessedImageChanged({
|
||||||
|
controlNetId,
|
||||||
|
processedControlImage: processedControlImage.image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
log.error({ graph: parseify(graph) }, t('queue.graphFailedToQueue'));
|
||||||
dispatch(
|
dispatch(
|
||||||
controlNetProcessedImageChanged({
|
addToast({
|
||||||
controlNetId,
|
title: t('queue.graphFailedToQueue'),
|
||||||
processedControlImage: processedControlImage.image_name,
|
status: 'error',
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { enqueueRequested } from 'app/store/actions';
|
||||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import {
|
import {
|
||||||
canvasSessionIdChanged,
|
canvasBatchIdAdded,
|
||||||
stagingAreaInitialized,
|
stagingAreaInitialized,
|
||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||||
@ -11,9 +11,11 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
|||||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
@ -30,13 +32,14 @@ import { startAppListening } from '..';
|
|||||||
* 8. Initialize the staging area if not yet initialized
|
* 8. Initialize the staging area if not yet initialized
|
||||||
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
|
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
|
||||||
*/
|
*/
|
||||||
export const addUserInvokedCanvasListener = () => {
|
export const addEnqueueRequestedCanvasListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||||
userInvoked.match(action) && action.payload === 'unifiedCanvas',
|
enqueueRequested.match(action) &&
|
||||||
effect: async (action, { getState, dispatch, take }) => {
|
action.payload.tabName === 'unifiedCanvas',
|
||||||
const log = logger('session');
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
const log = logger('queue');
|
||||||
|
const { prepend } = action.payload;
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@ -125,57 +128,59 @@ export const addUserInvokedCanvasListener = () => {
|
|||||||
// currently this action is just listened to for logging
|
// currently this action is just listened to for logging
|
||||||
dispatch(canvasGraphBuilt(graph));
|
dispatch(canvasGraphBuilt(graph));
|
||||||
|
|
||||||
// Create the session, store the request id
|
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||||
const { requestId: sessionCreatedRequestId } = dispatch(
|
|
||||||
sessionCreated({ graph })
|
|
||||||
);
|
|
||||||
|
|
||||||
// Take the session created action, matching by its request id
|
try {
|
||||||
const [sessionCreatedAction] = await take(
|
const req = dispatch(
|
||||||
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
sessionCreated.fulfilled.match(action) &&
|
fixedCacheKey: 'enqueueBatch',
|
||||||
action.meta.requestId === sessionCreatedRequestId
|
})
|
||||||
);
|
);
|
||||||
const session_id = sessionCreatedAction.payload.id;
|
|
||||||
|
const enqueueResult = await req.unwrap();
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||||
|
|
||||||
|
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
|
||||||
|
|
||||||
|
// Prep the canvas staging area if it is not yet initialized
|
||||||
|
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||||
|
dispatch(
|
||||||
|
stagingAreaInitialized({
|
||||||
|
boundingBox: {
|
||||||
|
...state.canvas.boundingBoxCoordinates,
|
||||||
|
...state.canvas.boundingBoxDimensions,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associate the session with the canvas session ID
|
||||||
|
dispatch(canvasBatchIdAdded(batchId));
|
||||||
|
|
||||||
// Associate the init image with the session, now that we have the session ID
|
|
||||||
if (['img2img', 'inpaint'].includes(generationMode) && canvasInitImage) {
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.endpoints.changeImageSessionId.initiate({
|
addToast({
|
||||||
imageDTO: canvasInitImage,
|
title: t('queue.batchQueued'),
|
||||||
session_id,
|
description: t('queue.batchQueuedDesc', {
|
||||||
|
item_count: enqueueResult.enqueued,
|
||||||
|
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||||
|
}),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
log.error(
|
||||||
|
{ batchConfig: parseify(batchConfig) },
|
||||||
|
t('queue.batchFailedToQueue')
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.batchFailedToQueue'),
|
||||||
|
status: 'error',
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Associate the mask image with the session, now that we have the session ID
|
|
||||||
if (['inpaint'].includes(generationMode) && canvasMaskImage) {
|
|
||||||
dispatch(
|
|
||||||
imagesApi.endpoints.changeImageSessionId.initiate({
|
|
||||||
imageDTO: canvasMaskImage,
|
|
||||||
session_id,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prep the canvas staging area if it is not yet initialized
|
|
||||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
|
||||||
dispatch(
|
|
||||||
stagingAreaInitialized({
|
|
||||||
sessionId: session_id,
|
|
||||||
boundingBox: {
|
|
||||||
...state.canvas.boundingBoxCoordinates,
|
|
||||||
...state.canvas.boundingBoxDimensions,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flag the session with the canvas session ID
|
|
||||||
dispatch(canvasSessionIdChanged(session_id));
|
|
||||||
|
|
||||||
// We are ready to invoke the session!
|
|
||||||
dispatch(sessionReadyToInvoke());
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
@ -0,0 +1,78 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { enqueueRequested } from 'app/store/actions';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||||
|
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||||
|
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
||||||
|
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
||||||
|
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
export const addEnqueueRequestedLinear = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||||
|
enqueueRequested.match(action) &&
|
||||||
|
(action.payload.tabName === 'txt2img' ||
|
||||||
|
action.payload.tabName === 'img2img'),
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
const log = logger('queue');
|
||||||
|
const state = getState();
|
||||||
|
const model = state.generation.model;
|
||||||
|
const { prepend } = action.payload;
|
||||||
|
|
||||||
|
let graph;
|
||||||
|
|
||||||
|
if (model && model.base_model === 'sdxl') {
|
||||||
|
if (action.payload.tabName === 'txt2img') {
|
||||||
|
graph = buildLinearSDXLTextToImageGraph(state);
|
||||||
|
} else {
|
||||||
|
graph = buildLinearSDXLImageToImageGraph(state);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (action.payload.tabName === 'txt2img') {
|
||||||
|
graph = buildLinearTextToImageGraph(state);
|
||||||
|
} else {
|
||||||
|
graph = buildLinearImageToImageGraph(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const req = dispatch(
|
||||||
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
|
fixedCacheKey: 'enqueueBatch',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
const enqueueResult = await req.unwrap();
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.batchQueued'),
|
||||||
|
description: t('queue.batchQueuedDesc', {
|
||||||
|
item_count: enqueueResult.enqueued,
|
||||||
|
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||||
|
}),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
log.error(
|
||||||
|
{ batchConfig: parseify(batchConfig) },
|
||||||
|
t('queue.batchFailedToQueue')
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.batchFailedToQueue'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,62 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { enqueueRequested } from 'app/store/actions';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
|
import { BatchConfig } from 'services/api/types';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
export const addEnqueueRequestedNodes = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||||
|
enqueueRequested.match(action) && action.payload.tabName === 'nodes',
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
const log = logger('queue');
|
||||||
|
const state = getState();
|
||||||
|
const { prepend } = action.payload;
|
||||||
|
const graph = buildNodesGraph(state.nodes);
|
||||||
|
const batchConfig: BatchConfig = {
|
||||||
|
batch: {
|
||||||
|
graph,
|
||||||
|
runs: state.generation.iterations,
|
||||||
|
},
|
||||||
|
prepend: action.payload.prepend,
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const req = dispatch(
|
||||||
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
|
fixedCacheKey: 'enqueueBatch',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
const enqueueResult = await req.unwrap();
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.batchQueued'),
|
||||||
|
description: t('queue.batchQueuedDesc', {
|
||||||
|
item_count: enqueueResult.enqueued,
|
||||||
|
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||||
|
}),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
log.error(
|
||||||
|
{ batchConfig: parseify(batchConfig) },
|
||||||
|
'Failed to enqueue batch'
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.batchFailedToQueue'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,67 @@
|
|||||||
|
import { isAnyOf } from '@reduxjs/toolkit';
|
||||||
|
import {
|
||||||
|
combinatorialToggled,
|
||||||
|
isErrorChanged,
|
||||||
|
isLoadingChanged,
|
||||||
|
maxPromptsChanged,
|
||||||
|
maxPromptsReset,
|
||||||
|
parsingErrorChanged,
|
||||||
|
promptsChanged,
|
||||||
|
} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||||
|
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
|
||||||
|
import { utilitiesApi } from 'services/api/endpoints/utilities';
|
||||||
|
import { appSocketConnected } from 'services/events/actions';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
const matcher = isAnyOf(
|
||||||
|
setPositivePrompt,
|
||||||
|
combinatorialToggled,
|
||||||
|
maxPromptsChanged,
|
||||||
|
maxPromptsReset,
|
||||||
|
appSocketConnected
|
||||||
|
);
|
||||||
|
|
||||||
|
export const addDynamicPromptsListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher,
|
||||||
|
effect: async (
|
||||||
|
action,
|
||||||
|
{ dispatch, getState, cancelActiveListeners, delay }
|
||||||
|
) => {
|
||||||
|
// debounce request
|
||||||
|
cancelActiveListeners();
|
||||||
|
await delay(1000);
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
if (state.config.disabledFeatures.includes('dynamicPrompting')) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { positivePrompt } = state.generation;
|
||||||
|
const { maxPrompts } = state.dynamicPrompts;
|
||||||
|
|
||||||
|
dispatch(isLoadingChanged(true));
|
||||||
|
|
||||||
|
try {
|
||||||
|
const req = dispatch(
|
||||||
|
utilitiesApi.endpoints.dynamicPrompts.initiate({
|
||||||
|
prompt: positivePrompt,
|
||||||
|
max_prompts: maxPrompts,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const res = await req.unwrap();
|
||||||
|
req.unsubscribe();
|
||||||
|
|
||||||
|
dispatch(promptsChanged(res.prompts));
|
||||||
|
dispatch(parsingErrorChanged(res.error));
|
||||||
|
dispatch(isErrorChanged(false));
|
||||||
|
dispatch(isLoadingChanged(false));
|
||||||
|
} catch {
|
||||||
|
dispatch(isErrorChanged(true));
|
||||||
|
dispatch(isLoadingChanged(false));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,18 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
|
||||||
import { sessionInvoked } from 'services/api/thunks/session';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
|
|
||||||
export const addSessionReadyToInvokeListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: sessionReadyToInvoke,
|
|
||||||
effect: (action, { getState, dispatch }) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const { sessionId: session_id } = getState().system;
|
|
||||||
if (session_id) {
|
|
||||||
log.debug({ session_id }, `Session ready to invoke (${session_id})})`);
|
|
||||||
dispatch(sessionInvoked({ session_id }));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -1,11 +1,9 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { LIST_TAG } from 'services/api';
|
import { size } from 'lodash-es';
|
||||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
import { api } from 'services/api';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
|
||||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
import { size } from 'lodash-es';
|
|
||||||
|
|
||||||
export const addSocketConnectedEventListener = () => {
|
export const addSocketConnectedEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -23,22 +21,10 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
dispatch(receivedOpenAPISchema());
|
dispatch(receivedOpenAPISchema());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dispatch(api.util.resetApiState());
|
||||||
|
|
||||||
// pass along the socket event as an application action
|
// pass along the socket event as an application action
|
||||||
dispatch(appSocketConnected(action.payload));
|
dispatch(appSocketConnected(action.payload));
|
||||||
|
|
||||||
// update all server state
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.invalidateTags([
|
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
|
||||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
|
||||||
{ type: 'LoRAModel', id: LIST_TAG },
|
|
||||||
{ type: 'ControlNetModel', id: LIST_TAG },
|
|
||||||
{ type: 'VaeModel', id: LIST_TAG },
|
|
||||||
{ type: 'TextualInversionModel', id: LIST_TAG },
|
|
||||||
{ type: 'ScannedModels', id: LIST_TAG },
|
|
||||||
])
|
|
||||||
);
|
|
||||||
dispatch(appInfoApi.util.invalidateTags(['AppConfig', 'AppVersion']));
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { api } from 'services/api';
|
||||||
import {
|
import {
|
||||||
appSocketDisconnected,
|
appSocketDisconnected,
|
||||||
socketDisconnected,
|
socketDisconnected,
|
||||||
@ -11,6 +12,9 @@ export const addSocketDisconnectedEventListener = () => {
|
|||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch }) => {
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
log.debug('Disconnected');
|
log.debug('Disconnected');
|
||||||
|
|
||||||
|
dispatch(api.util.resetApiState());
|
||||||
|
|
||||||
// pass along the socket event as an application action
|
// pass along the socket event as an application action
|
||||||
dispatch(appSocketDisconnected(action.payload));
|
dispatch(appSocketDisconnected(action.payload));
|
||||||
},
|
},
|
||||||
|
@ -8,25 +8,11 @@ import { startAppListening } from '../..';
|
|||||||
export const addGeneratorProgressEventListener = () => {
|
export const addGeneratorProgressEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketGeneratorProgress,
|
actionCreator: socketGeneratorProgress,
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch }) => {
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
if (
|
|
||||||
getState().system.canceledSession ===
|
|
||||||
action.payload.data.graph_execution_state_id
|
|
||||||
) {
|
|
||||||
log.trace(
|
|
||||||
action.payload,
|
|
||||||
'Ignored generator progress for canceled session'
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
log.trace(
|
log.trace(action.payload, `Generator progress`);
|
||||||
action.payload,
|
|
||||||
`Generator progress (${action.payload.data.node.type})`
|
|
||||||
);
|
|
||||||
|
|
||||||
// pass along the socket event as an application action
|
|
||||||
dispatch(appSocketGeneratorProgress(action.payload));
|
dispatch(appSocketGeneratorProgress(action.payload));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -8,10 +8,8 @@ import {
|
|||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graphBuilders/constants';
|
import { CANVAS_OUTPUT } from 'features/nodes/util/graphBuilders/constants';
|
||||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { isImageOutput } from 'services/api/guards';
|
import { isImageOutput } from 'services/api/guards';
|
||||||
import { sessionCanceled } from 'services/api/thunks/session';
|
|
||||||
import { imagesAdapter } from 'services/api/util';
|
import { imagesAdapter } from 'services/api/util';
|
||||||
import {
|
import {
|
||||||
appSocketInvocationComplete,
|
appSocketInvocationComplete,
|
||||||
@ -31,14 +29,6 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
{ data: parseify(data) },
|
{ data: parseify(data) },
|
||||||
`Invocation complete (${action.payload.data.node.type})`
|
`Invocation complete (${action.payload.data.node.type})`
|
||||||
);
|
);
|
||||||
const session_id = action.payload.data.graph_execution_state_id;
|
|
||||||
|
|
||||||
const { cancelType, isCancelScheduled } = getState().system;
|
|
||||||
|
|
||||||
// Handle scheduled cancelation
|
|
||||||
if (cancelType === 'scheduled' && isCancelScheduled) {
|
|
||||||
dispatch(sessionCanceled({ session_id }));
|
|
||||||
}
|
|
||||||
|
|
||||||
const { result, node, graph_execution_state_id } = data;
|
const { result, node, graph_execution_state_id } = data;
|
||||||
|
|
||||||
@ -53,8 +43,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
|
|
||||||
// Add canvas images to the staging area
|
// Add canvas images to the staging area
|
||||||
if (
|
if (
|
||||||
graph_execution_state_id ===
|
canvas.sessionIds.includes(graph_execution_state_id) &&
|
||||||
canvas.layerState.stagingArea.sessionId &&
|
|
||||||
[CANVAS_OUTPUT].includes(data.source_node_id)
|
[CANVAS_OUTPUT].includes(data.source_node_id)
|
||||||
) {
|
) {
|
||||||
dispatch(addImageToStagingArea(imageDTO));
|
dispatch(addImageToStagingArea(imageDTO));
|
||||||
@ -87,6 +76,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
categories: IMAGE_CATEGORIES,
|
categories: IMAGE_CATEGORIES,
|
||||||
},
|
},
|
||||||
(draft) => {
|
(draft) => {
|
||||||
|
console.log(draft);
|
||||||
imagesAdapter.addOne(draft, imageDTO);
|
imagesAdapter.addOne(draft, imageDTO);
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -114,8 +104,6 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
dispatch(imageSelected(imageDTO));
|
dispatch(imageSelected(imageDTO));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(progressImageSet(null));
|
|
||||||
}
|
}
|
||||||
// pass along the socket event as an application action
|
// pass along the socket event as an application action
|
||||||
dispatch(appSocketInvocationComplete(action.payload));
|
dispatch(appSocketInvocationComplete(action.payload));
|
||||||
|
@ -8,23 +8,14 @@ import { startAppListening } from '../..';
|
|||||||
export const addInvocationStartedEventListener = () => {
|
export const addInvocationStartedEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketInvocationStarted,
|
actionCreator: socketInvocationStarted,
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch }) => {
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
if (
|
|
||||||
getState().system.canceledSession ===
|
|
||||||
action.payload.data.graph_execution_state_id
|
|
||||||
) {
|
|
||||||
log.trace(
|
|
||||||
action.payload,
|
|
||||||
'Ignored invocation started for canceled session'
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
action.payload,
|
action.payload,
|
||||||
`Invocation started (${action.payload.data.node.type})`
|
`Invocation started (${action.payload.data.node.type})`
|
||||||
);
|
);
|
||||||
|
|
||||||
dispatch(appSocketInvocationStarted(action.payload));
|
dispatch(appSocketInvocationStarted(action.payload));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { canvasSessionIdAdded } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||||
|
import {
|
||||||
|
appSocketQueueItemStatusChanged,
|
||||||
|
socketQueueItemStatusChanged,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
|
export const addSocketQueueItemStatusChangedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketQueueItemStatusChanged,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const log = logger('socketio');
|
||||||
|
const {
|
||||||
|
queue_item_id: item_id,
|
||||||
|
batch_id,
|
||||||
|
graph_execution_state_id,
|
||||||
|
status,
|
||||||
|
} = action.payload.data;
|
||||||
|
log.debug(
|
||||||
|
action.payload,
|
||||||
|
`Queue item ${item_id} status updated: ${status}`
|
||||||
|
);
|
||||||
|
dispatch(appSocketQueueItemStatusChanged(action.payload));
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
|
||||||
|
if (!draft) {
|
||||||
|
console.log('no draft!');
|
||||||
|
}
|
||||||
|
queueItemsAdapter.updateOne(draft, {
|
||||||
|
id: item_id,
|
||||||
|
changes: action.payload.data,
|
||||||
|
});
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
if (state.canvas.batchIds.includes(batch_id)) {
|
||||||
|
dispatch(canvasSessionIdAdded(graph_execution_state_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
queueApi.util.invalidateTags([
|
||||||
|
'CurrentSessionQueueItem',
|
||||||
|
'NextSessionQueueItem',
|
||||||
|
'SessionQueueStatus',
|
||||||
|
{ type: 'SessionQueueItem', id: item_id },
|
||||||
|
{ type: 'SessionQueueItemDTO', id: item_id },
|
||||||
|
{ type: 'BatchStatus', id: batch_id },
|
||||||
|
])
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,14 +1,17 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { appSocketSubscribed, socketSubscribed } from 'services/events/actions';
|
import {
|
||||||
|
appSocketSubscribedSession,
|
||||||
|
socketSubscribedSession,
|
||||||
|
} from 'services/events/actions';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
export const addSocketSubscribedEventListener = () => {
|
export const addSocketSubscribedEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketSubscribed,
|
actionCreator: socketSubscribedSession,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch }) => {
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
log.debug(action.payload, 'Subscribed');
|
log.debug(action.payload, 'Subscribed');
|
||||||
dispatch(appSocketSubscribed(action.payload));
|
dispatch(appSocketSubscribedSession(action.payload));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import {
|
import {
|
||||||
appSocketUnsubscribed,
|
appSocketUnsubscribedSession,
|
||||||
socketUnsubscribed,
|
socketUnsubscribedSession,
|
||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
export const addSocketUnsubscribedEventListener = () => {
|
export const addSocketUnsubscribedEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketUnsubscribed,
|
actionCreator: socketUnsubscribedSession,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch }) => {
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
log.debug(action.payload, 'Unsubscribed');
|
log.debug(action.payload, 'Unsubscribed');
|
||||||
dispatch(appSocketUnsubscribed(action.payload));
|
dispatch(appSocketUnsubscribedSession(action.payload));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph';
|
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph';
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
import { t } from 'i18next';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const upscaleRequested = createAction<{ image_name: string }>(
|
export const upscaleRequested = createAction<{ image_name: string }>(
|
||||||
@ -11,7 +14,9 @@ export const upscaleRequested = createAction<{ image_name: string }>(
|
|||||||
export const addUpscaleRequestedListener = () => {
|
export const addUpscaleRequestedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: upscaleRequested,
|
actionCreator: upscaleRequested,
|
||||||
effect: async (action, { dispatch, getState, take }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const log = logger('session');
|
||||||
|
|
||||||
const { image_name } = action.payload;
|
const { image_name } = action.payload;
|
||||||
const { esrganModelName } = getState().postprocessing;
|
const { esrganModelName } = getState().postprocessing;
|
||||||
|
|
||||||
@ -20,12 +25,31 @@ export const addUpscaleRequestedListener = () => {
|
|||||||
esrganModelName,
|
esrganModelName,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a session to run the graph & wait til it's ready to invoke
|
try {
|
||||||
dispatch(sessionCreated({ graph }));
|
const req = dispatch(
|
||||||
|
queueApi.endpoints.enqueueGraph.initiate(
|
||||||
|
{ graph, prepend: true },
|
||||||
|
{
|
||||||
|
fixedCacheKey: 'enqueueGraph',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
await take(sessionCreated.fulfilled.match);
|
const enqueueResult = await req.unwrap();
|
||||||
|
req.reset();
|
||||||
dispatch(sessionReadyToInvoke());
|
log.debug(
|
||||||
|
{ enqueueResult: parseify(enqueueResult) },
|
||||||
|
t('queue.graphQueued')
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
log.error({ graph: parseify(graph) }, t('queue.graphFailedToQueue'));
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.graphFailedToQueue'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,38 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { userInvoked } from 'app/store/actions';
|
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
|
||||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
|
||||||
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
|
|
||||||
export const addUserInvokedImageToImageListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
|
||||||
userInvoked.match(action) && action.payload === 'img2img',
|
|
||||||
effect: async (action, { getState, dispatch, take }) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const state = getState();
|
|
||||||
const model = state.generation.model;
|
|
||||||
|
|
||||||
let graph;
|
|
||||||
|
|
||||||
if (model && model.base_model === 'sdxl') {
|
|
||||||
graph = buildLinearSDXLImageToImageGraph(state);
|
|
||||||
} else {
|
|
||||||
graph = buildLinearImageToImageGraph(state);
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(imageToImageGraphBuilt(graph));
|
|
||||||
log.debug({ graph: parseify(graph) }, 'Image to Image graph built');
|
|
||||||
|
|
||||||
dispatch(sessionCreated({ graph }));
|
|
||||||
|
|
||||||
await take(sessionCreated.fulfilled.match);
|
|
||||||
|
|
||||||
dispatch(sessionReadyToInvoke());
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -1,29 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { userInvoked } from 'app/store/actions';
|
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import { nodesGraphBuilt } from 'features/nodes/store/actions';
|
|
||||||
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
|
|
||||||
export const addUserInvokedNodesListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
|
||||||
userInvoked.match(action) && action.payload === 'nodes',
|
|
||||||
effect: async (action, { getState, dispatch, take }) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const state = getState();
|
|
||||||
|
|
||||||
const graph = buildNodesGraph(state.nodes);
|
|
||||||
dispatch(nodesGraphBuilt(graph));
|
|
||||||
log.debug({ graph: parseify(graph) }, 'Nodes graph built');
|
|
||||||
|
|
||||||
dispatch(sessionCreated({ graph }));
|
|
||||||
|
|
||||||
await take(sessionCreated.fulfilled.match);
|
|
||||||
|
|
||||||
dispatch(sessionReadyToInvoke());
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -1,39 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { userInvoked } from 'app/store/actions';
|
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
|
||||||
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
|
||||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
|
|
||||||
export const addUserInvokedTextToImageListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
|
||||||
userInvoked.match(action) && action.payload === 'txt2img',
|
|
||||||
effect: async (action, { getState, dispatch, take }) => {
|
|
||||||
const log = logger('session');
|
|
||||||
const state = getState();
|
|
||||||
const model = state.generation.model;
|
|
||||||
|
|
||||||
let graph;
|
|
||||||
|
|
||||||
if (model && model.base_model === 'sdxl') {
|
|
||||||
graph = buildLinearSDXLTextToImageGraph(state);
|
|
||||||
} else {
|
|
||||||
graph = buildLinearTextToImageGraph(state);
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(textToImageGraphBuilt(graph));
|
|
||||||
|
|
||||||
log.debug({ graph: parseify(graph) }, 'Text to Image graph built');
|
|
||||||
|
|
||||||
dispatch(sessionCreated({ graph }));
|
|
||||||
|
|
||||||
await take(sessionCreated.fulfilled.match);
|
|
||||||
|
|
||||||
dispatch(sessionReadyToInvoke());
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -0,0 +1,54 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { AppThunkDispatch } from 'app/store/store';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
|
import { BatchConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
export const enqueueBatch = async (
|
||||||
|
batchConfig: BatchConfig,
|
||||||
|
dispatch: AppThunkDispatch
|
||||||
|
) => {
|
||||||
|
const log = logger('session');
|
||||||
|
const { prepend } = batchConfig;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const req = dispatch(
|
||||||
|
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||||
|
fixedCacheKey: 'enqueueBatch',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
const enqueueResult = await req.unwrap();
|
||||||
|
req.reset();
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||||
|
fixedCacheKey: 'resumeProcessor',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.batchQueued'),
|
||||||
|
description: t('queue.batchQueuedDesc', {
|
||||||
|
item_count: enqueueResult.enqueued,
|
||||||
|
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||||
|
}),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
log.error(
|
||||||
|
{ batchConfig: parseify(batchConfig) },
|
||||||
|
t('queue.batchFailedToQueue')
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: t('queue.batchFailedToQueue'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
5
invokeai/frontend/web/src/app/store/nanostores/store.ts
Normal file
5
invokeai/frontend/web/src/app/store/nanostores/store.ts
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import { Store } from '@reduxjs/toolkit';
|
||||||
|
import { atom } from 'nanostores';
|
||||||
|
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
export const $store = atom<Store<any> | undefined>();
|
@ -18,6 +18,7 @@ import postprocessingReducer from 'features/parameters/store/postprocessingSlice
|
|||||||
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
|
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
|
||||||
import configReducer from 'features/system/store/configSlice';
|
import configReducer from 'features/system/store/configSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
|
import queueReducer from 'features/queue/store/queueSlice';
|
||||||
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
@ -31,6 +32,7 @@ import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
|||||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||||
|
import { $store } from './nanostores/store';
|
||||||
|
|
||||||
const allReducers = {
|
const allReducers = {
|
||||||
canvas: canvasReducer,
|
canvas: canvasReducer,
|
||||||
@ -49,6 +51,7 @@ const allReducers = {
|
|||||||
lora: loraReducer,
|
lora: loraReducer,
|
||||||
modelmanager: modelmanagerReducer,
|
modelmanager: modelmanagerReducer,
|
||||||
sdxl: sdxlReducer,
|
sdxl: sdxlReducer,
|
||||||
|
queue: queueReducer,
|
||||||
[api.reducerPath]: api.reducer,
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -121,3 +124,4 @@ export type RootState = ReturnType<typeof store.getState>;
|
|||||||
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
|
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
|
||||||
export type AppDispatch = typeof store.dispatch;
|
export type AppDispatch = typeof store.dispatch;
|
||||||
export const stateSelector = (state: RootState) => state;
|
export const stateSelector = (state: RootState) => state;
|
||||||
|
$store.set(store);
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import {
|
import {
|
||||||
As,
|
As,
|
||||||
ChakraProps,
|
|
||||||
Flex,
|
Flex,
|
||||||
|
FlexProps,
|
||||||
Icon,
|
Icon,
|
||||||
Skeleton,
|
Skeleton,
|
||||||
Spinner,
|
Spinner,
|
||||||
@ -47,15 +47,14 @@ export const IAILoadingImageFallback = (props: Props) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
type IAINoImageFallbackProps = {
|
type IAINoImageFallbackProps = FlexProps & {
|
||||||
label?: string;
|
label?: string;
|
||||||
icon?: As | null;
|
icon?: As | null;
|
||||||
boxSize?: StyleProps['boxSize'];
|
boxSize?: StyleProps['boxSize'];
|
||||||
sx?: ChakraProps['sx'];
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
||||||
const { icon = FaImage, boxSize = 16 } = props;
|
const { icon = FaImage, boxSize = 16, sx, ...rest } = props;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -73,8 +72,9 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
|||||||
_dark: {
|
_dark: {
|
||||||
color: 'base.500',
|
color: 'base.500',
|
||||||
},
|
},
|
||||||
...props.sx,
|
...sx,
|
||||||
}}
|
}}
|
||||||
|
{...rest}
|
||||||
>
|
>
|
||||||
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
|
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
|
||||||
{props.label && <Text textAlign="center">{props.label}</Text>}
|
{props.label && <Text textAlign="center">{props.label}</Text>}
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
import { Box, Text } from '@chakra-ui/react';
|
||||||
|
import { forwardRef, memo } from 'react';
|
||||||
|
|
||||||
|
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||||
|
label: string;
|
||||||
|
value: string;
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const IAIMantineSelectItemWithDescription = forwardRef<
|
||||||
|
HTMLDivElement,
|
||||||
|
ItemProps
|
||||||
|
>(({ label, description, ...rest }: ItemProps, ref) => (
|
||||||
|
<Box ref={ref} {...rest}>
|
||||||
|
<Box>
|
||||||
|
<Text fontWeight={600}>{label}</Text>
|
||||||
|
{description && (
|
||||||
|
<Text size="xs" variant="subtext">
|
||||||
|
{description}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
));
|
||||||
|
|
||||||
|
IAIMantineSelectItemWithDescription.displayName =
|
||||||
|
'IAIMantineSelectItemWithDescription';
|
||||||
|
|
||||||
|
export default memo(IAIMantineSelectItemWithDescription);
|
@ -4,7 +4,6 @@ import { useAppToaster } from 'app/components/Toaster';
|
|||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import {
|
import {
|
||||||
@ -51,7 +50,6 @@ type ImageUploaderProps = {
|
|||||||
const ImageUploader = (props: ImageUploaderProps) => {
|
const ImageUploader = (props: ImageUploaderProps) => {
|
||||||
const { children } = props;
|
const { children } = props;
|
||||||
const { autoAddBoardId, postUploadAction } = useAppSelector(selector);
|
const { autoAddBoardId, postUploadAction } = useAppSelector(selector);
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
|
||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
|
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
|
||||||
@ -106,6 +104,10 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
[t, toaster, fileAcceptedCallback, fileRejectionCallback]
|
[t, toaster, fileAcceptedCallback, fileRejectionCallback]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const onDragOver = useCallback(() => {
|
||||||
|
setIsHandlingUpload(true);
|
||||||
|
}, []);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
getRootProps,
|
getRootProps,
|
||||||
getInputProps,
|
getInputProps,
|
||||||
@ -117,8 +119,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||||
noClick: true,
|
noClick: true,
|
||||||
onDrop,
|
onDrop,
|
||||||
onDragOver: () => setIsHandlingUpload(true),
|
onDragOver,
|
||||||
disabled: isBusy,
|
|
||||||
multiple: false,
|
multiple: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -4,25 +4,22 @@ import { useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { isInvocationNode } from 'features/nodes/types/types';
|
import { isInvocationNode } from 'features/nodes/types/types';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import i18n from 'i18next';
|
||||||
import { forEach, map } from 'lodash-es';
|
import { forEach, map } from 'lodash-es';
|
||||||
import { getConnectedEdges } from 'reactflow';
|
import { getConnectedEdges } from 'reactflow';
|
||||||
import i18n from 'i18next';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
(state, activeTabName) => {
|
(
|
||||||
const { generation, system, nodes } = state;
|
{ controlNet, generation, system, nodes, dynamicPrompts },
|
||||||
|
activeTabName
|
||||||
|
) => {
|
||||||
const { initialImage, model } = generation;
|
const { initialImage, model } = generation;
|
||||||
|
|
||||||
const { isProcessing, isConnected } = system;
|
const { isConnected } = system;
|
||||||
|
|
||||||
const reasons: string[] = [];
|
const reasons: string[] = [];
|
||||||
|
|
||||||
// Cannot generate if already processing an image
|
|
||||||
if (isProcessing) {
|
|
||||||
reasons.push(i18n.t('parameters.invoke.systemBusy'));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cannot generate if not connected
|
// Cannot generate if not connected
|
||||||
if (!isConnected) {
|
if (!isConnected) {
|
||||||
reasons.push(i18n.t('parameters.invoke.systemDisconnected'));
|
reasons.push(i18n.t('parameters.invoke.systemDisconnected'));
|
||||||
@ -82,12 +79,16 @@ const selector = createSelector(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (dynamicPrompts.prompts.length === 0) {
|
||||||
|
reasons.push(i18n.t('parameters.invoke.noPrompts'));
|
||||||
|
}
|
||||||
|
|
||||||
if (!model) {
|
if (!model) {
|
||||||
reasons.push(i18n.t('parameters.invoke.noModelSelected'));
|
reasons.push(i18n.t('parameters.invoke.noModelSelected'));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.controlNet.isEnabled) {
|
if (controlNet.isEnabled) {
|
||||||
map(state.controlNet.controlNets).forEach((controlNet, i) => {
|
map(controlNet.controlNets).forEach((controlNet, i) => {
|
||||||
if (!controlNet.isEnabled) {
|
if (!controlNet.isEnabled) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -112,12 +113,12 @@ const selector = createSelector(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return { isReady: !reasons.length, isProcessing, reasons };
|
return { isReady: !reasons.length, reasons };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
export const useIsReadyToInvoke = () => {
|
export const useIsReadyToEnqueue = () => {
|
||||||
const { isReady, isProcessing, reasons } = useAppSelector(selector);
|
const { isReady, reasons } = useAppSelector(selector);
|
||||||
return { isReady, isProcessing, reasons };
|
return { isReady, reasons };
|
||||||
};
|
};
|
28
invokeai/frontend/web/src/common/util/generateSeeds.ts
Normal file
28
invokeai/frontend/web/src/common/util/generateSeeds.ts
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
|
||||||
|
import { random } from 'lodash-es';
|
||||||
|
|
||||||
|
export type GenerateSeedsArg = {
|
||||||
|
count: number;
|
||||||
|
start?: number;
|
||||||
|
min?: number;
|
||||||
|
max?: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const generateSeeds = ({
|
||||||
|
count,
|
||||||
|
start,
|
||||||
|
min = NUMPY_RAND_MIN,
|
||||||
|
max = NUMPY_RAND_MAX,
|
||||||
|
}: GenerateSeedsArg) => {
|
||||||
|
const first = start ?? random(min, max);
|
||||||
|
const seeds: number[] = [];
|
||||||
|
for (let i = first; i < first + count; i++) {
|
||||||
|
seeds.push(i % max);
|
||||||
|
}
|
||||||
|
return seeds;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const generateOneSeed = (
|
||||||
|
min: number = NUMPY_RAND_MIN,
|
||||||
|
max: number = NUMPY_RAND_MAX
|
||||||
|
) => random(min, max);
|
@ -153,8 +153,8 @@ const IAICanvas = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
resizeObserver.observe(containerRef.current);
|
resizeObserver.observe(containerRef.current);
|
||||||
|
const { width, height } = containerRef.current.getBoundingClientRect();
|
||||||
dispatch(canvasResized(containerRef.current.getBoundingClientRect()));
|
dispatch(canvasResized({ width, height }));
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
resizeObserver.disconnect();
|
resizeObserver.disconnect();
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { memo, useEffect, useState } from 'react';
|
import { memo, useEffect, useState } from 'react';
|
||||||
import { Image as KonvaImage } from 'react-konva';
|
import { Image as KonvaImage } from 'react-konva';
|
||||||
import { canvasSelector } from '../store/canvasSelectors';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[systemSelector, canvasSelector],
|
[stateSelector],
|
||||||
(system, canvas) => {
|
({ system, canvas }) => {
|
||||||
const { progressImage, sessionId } = system;
|
const { denoiseProgress } = system;
|
||||||
const { sessionId: canvasSessionId, boundingBox } =
|
const { boundingBox } = canvas.layerState.stagingArea;
|
||||||
canvas.layerState.stagingArea;
|
const { sessionIds } = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
boundingBox,
|
boundingBox,
|
||||||
progressImage: sessionId === canvasSessionId ? progressImage : undefined,
|
progressImage:
|
||||||
|
denoiseProgress && sessionIds.includes(denoiseProgress.session_id)
|
||||||
|
? denoiseProgress.progress_image
|
||||||
|
: undefined,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -11,8 +11,9 @@ import {
|
|||||||
setShouldShowStagingImage,
|
setShouldShowStagingImage,
|
||||||
setShouldShowStagingOutline,
|
setShouldShowStagingOutline,
|
||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
|
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -25,16 +26,15 @@ import {
|
|||||||
FaPlus,
|
FaPlus,
|
||||||
FaSave,
|
FaSave,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import { stagingAreaImageSaved } from '../store/actions';
|
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { stagingAreaImageSaved } from '../store/actions';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[canvasSelector],
|
[canvasSelector],
|
||||||
(canvas) => {
|
(canvas) => {
|
||||||
const {
|
const {
|
||||||
layerState: {
|
layerState: {
|
||||||
stagingArea: { images, selectedImageIndex, sessionId },
|
stagingArea: { images, selectedImageIndex },
|
||||||
},
|
},
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
@ -47,14 +47,9 @@ const selector = createSelector(
|
|||||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
sessionId,
|
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
defaultSelectorOptions
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const IAICanvasStagingAreaToolbar = () => {
|
const IAICanvasStagingAreaToolbar = () => {
|
||||||
@ -64,7 +59,6 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
isOnLastImage,
|
isOnLastImage,
|
||||||
currentStagingAreaImage,
|
currentStagingAreaImage,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
sessionId,
|
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -121,8 +115,8 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const handleAccept = useCallback(
|
const handleAccept = useCallback(
|
||||||
() => dispatch(commitStagingAreaImage(sessionId)),
|
() => dispatch(commitStagingAreaImage()),
|
||||||
[dispatch, sessionId]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: imageDTO } = useGetImageDTOQuery(
|
const { data: imageDTO } = useGetImageDTOQuery(
|
||||||
|
@ -1,24 +1,23 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { FaRedo } from 'react-icons/fa';
|
import { FaRedo } from 'react-icons/fa';
|
||||||
|
|
||||||
import { redo } from 'features/canvas/store/canvasSlice';
|
import { redo } from 'features/canvas/store/canvasSlice';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
|
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const canvasRedoSelector = createSelector(
|
const canvasRedoSelector = createSelector(
|
||||||
[canvasSelector, activeTabNameSelector, systemSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
(canvas, activeTabName, system) => {
|
({ canvas }, activeTabName) => {
|
||||||
const { futureLayerStates } = canvas;
|
const { futureLayerStates } = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
canRedo: futureLayerStates.length > 0 && !system.isProcessing,
|
canRedo: futureLayerStates.length > 0,
|
||||||
activeTabName,
|
activeTabName,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAIPopover from 'common/components/IAIPopover';
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import {
|
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
canvasSelector,
|
|
||||||
isStagingSelector,
|
|
||||||
} from 'features/canvas/store/canvasSelectors';
|
|
||||||
import {
|
import {
|
||||||
addEraseRect,
|
addEraseRect,
|
||||||
addFillRect,
|
addFillRect,
|
||||||
@ -16,7 +14,6 @@ import {
|
|||||||
setBrushSize,
|
setBrushSize,
|
||||||
setTool,
|
setTool,
|
||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { clamp, isEqual } from 'lodash-es';
|
import { clamp, isEqual } from 'lodash-es';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
@ -32,15 +29,13 @@ import {
|
|||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[canvasSelector, isStagingSelector, systemSelector],
|
[stateSelector, isStagingSelector],
|
||||||
(canvas, isStaging, system) => {
|
({ canvas }, isStaging) => {
|
||||||
const { isProcessing } = system;
|
|
||||||
const { tool, brushColor, brushSize } = canvas;
|
const { tool, brushColor, brushSize } = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
tool,
|
tool,
|
||||||
isStaging,
|
isStaging,
|
||||||
isProcessing,
|
|
||||||
brushColor,
|
brushColor,
|
||||||
brushSize,
|
brushSize,
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
@ -11,10 +12,7 @@ import {
|
|||||||
canvasMerged,
|
canvasMerged,
|
||||||
canvasSavedToGallery,
|
canvasSavedToGallery,
|
||||||
} from 'features/canvas/store/actions';
|
} from 'features/canvas/store/actions';
|
||||||
import {
|
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
canvasSelector,
|
|
||||||
isStagingSelector,
|
|
||||||
} from 'features/canvas/store/canvasSelectors';
|
|
||||||
import {
|
import {
|
||||||
resetCanvas,
|
resetCanvas,
|
||||||
resetCanvasView,
|
resetCanvasView,
|
||||||
@ -27,9 +25,9 @@ import {
|
|||||||
LAYER_NAMES_DICT,
|
LAYER_NAMES_DICT,
|
||||||
} from 'features/canvas/store/canvasTypes';
|
} from 'features/canvas/store/canvasTypes';
|
||||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { memo } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
@ -47,17 +45,14 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
|
|||||||
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
||||||
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
||||||
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[systemSelector, canvasSelector, isStagingSelector],
|
[stateSelector, isStagingSelector],
|
||||||
(system, canvas, isStaging) => {
|
({ canvas }, isStaging) => {
|
||||||
const { isProcessing } = system;
|
|
||||||
const { tool, shouldCropToBoundingBoxOnSave, layer, isMaskEnabled } =
|
const { tool, shouldCropToBoundingBoxOnSave, layer, isMaskEnabled } =
|
||||||
canvas;
|
canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isProcessing,
|
|
||||||
isStaging,
|
isStaging,
|
||||||
isMaskEnabled,
|
isMaskEnabled,
|
||||||
tool,
|
tool,
|
||||||
@ -74,8 +69,7 @@ export const selector = createSelector(
|
|||||||
|
|
||||||
const IAICanvasToolbar = () => {
|
const IAICanvasToolbar = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { isProcessing, isStaging, isMaskEnabled, layer, tool } =
|
const { isStaging, isMaskEnabled, layer, tool } = useAppSelector(selector);
|
||||||
useAppSelector(selector);
|
|
||||||
const canvasBaseLayer = getCanvasBaseLayer();
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -118,7 +112,7 @@ const IAICanvasToolbar = () => {
|
|||||||
enabled: () => !isStaging,
|
enabled: () => !isStaging,
|
||||||
preventDefault: true,
|
preventDefault: true,
|
||||||
},
|
},
|
||||||
[canvasBaseLayer, isProcessing]
|
[canvasBaseLayer]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -130,7 +124,7 @@ const IAICanvasToolbar = () => {
|
|||||||
enabled: () => !isStaging,
|
enabled: () => !isStaging,
|
||||||
preventDefault: true,
|
preventDefault: true,
|
||||||
},
|
},
|
||||||
[canvasBaseLayer, isProcessing]
|
[canvasBaseLayer]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -142,7 +136,7 @@ const IAICanvasToolbar = () => {
|
|||||||
enabled: () => !isStaging && isClipboardAPIAvailable,
|
enabled: () => !isStaging && isClipboardAPIAvailable,
|
||||||
preventDefault: true,
|
preventDefault: true,
|
||||||
},
|
},
|
||||||
[canvasBaseLayer, isProcessing, isClipboardAPIAvailable]
|
[canvasBaseLayer, isClipboardAPIAvailable]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -154,7 +148,7 @@ const IAICanvasToolbar = () => {
|
|||||||
enabled: () => !isStaging,
|
enabled: () => !isStaging,
|
||||||
preventDefault: true,
|
preventDefault: true,
|
||||||
},
|
},
|
||||||
[canvasBaseLayer, isProcessing]
|
[canvasBaseLayer]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleSelectMoveTool = () => dispatch(setTool('move'));
|
const handleSelectMoveTool = () => dispatch(setTool('move'));
|
||||||
|
@ -1,24 +1,23 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { FaUndo } from 'react-icons/fa';
|
import { FaUndo } from 'react-icons/fa';
|
||||||
|
|
||||||
import { undo } from 'features/canvas/store/canvasSlice';
|
import { undo } from 'features/canvas/store/canvasSlice';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
|
||||||
const canvasUndoSelector = createSelector(
|
const canvasUndoSelector = createSelector(
|
||||||
[canvasSelector, activeTabNameSelector, systemSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
(canvas, activeTabName, system) => {
|
({ canvas }, activeTabName) => {
|
||||||
const { pastLayerStates } = canvas;
|
const { pastLayerStates } = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
canUndo: pastLayerStates.length > 0 && !system.isProcessing,
|
canUndo: pastLayerStates.length > 0,
|
||||||
activeTabName,
|
activeTabName,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
@ -1,16 +1,12 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState, stateSelector } from 'app/store/store';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
import { CanvasImage, CanvasState, isCanvasBaseImage } from './canvasTypes';
|
import { CanvasImage, CanvasState, isCanvasBaseImage } from './canvasTypes';
|
||||||
|
|
||||||
export const canvasSelector = (state: RootState): CanvasState => state.canvas;
|
export const canvasSelector = (state: RootState): CanvasState => state.canvas;
|
||||||
|
|
||||||
export const isStagingSelector = createSelector(
|
export const isStagingSelector = createSelector(
|
||||||
[canvasSelector, activeTabNameSelector, systemSelector],
|
[stateSelector],
|
||||||
(canvas, activeTabName, system) =>
|
({ canvas }) => canvas.layerState.stagingArea.images.length > 0
|
||||||
canvas.layerState.stagingArea.images.length > 0 ||
|
|
||||||
(activeTabName === 'unifiedCanvas' && system.isProcessing)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
export const initialCanvasImageSelector = (
|
export const initialCanvasImageSelector = (
|
||||||
|
@ -85,6 +85,8 @@ export const initialCanvasState: CanvasState = {
|
|||||||
stageDimensions: { width: 0, height: 0 },
|
stageDimensions: { width: 0, height: 0 },
|
||||||
stageScale: 1,
|
stageScale: 1,
|
||||||
tool: 'brush',
|
tool: 'brush',
|
||||||
|
sessionIds: [],
|
||||||
|
batchIds: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
export const canvasSlice = createSlice({
|
export const canvasSlice = createSlice({
|
||||||
@ -297,18 +299,26 @@ export const canvasSlice = createSlice({
|
|||||||
setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => {
|
setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => {
|
||||||
state.isMoveStageKeyHeld = action.payload;
|
state.isMoveStageKeyHeld = action.payload;
|
||||||
},
|
},
|
||||||
canvasSessionIdChanged: (state, action: PayloadAction<string>) => {
|
canvasBatchIdAdded: (state, action: PayloadAction<string>) => {
|
||||||
state.layerState.stagingArea.sessionId = action.payload;
|
state.batchIds.push(action.payload);
|
||||||
|
},
|
||||||
|
canvasSessionIdAdded: (state, action: PayloadAction<string>) => {
|
||||||
|
state.sessionIds.push(action.payload);
|
||||||
|
},
|
||||||
|
canvasBatchesAndSessionsReset: (state) => {
|
||||||
|
state.sessionIds = [];
|
||||||
|
state.batchIds = [];
|
||||||
},
|
},
|
||||||
stagingAreaInitialized: (
|
stagingAreaInitialized: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ sessionId: string; boundingBox: IRect }>
|
action: PayloadAction<{
|
||||||
|
boundingBox: IRect;
|
||||||
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { sessionId, boundingBox } = action.payload;
|
const { boundingBox } = action.payload;
|
||||||
|
|
||||||
state.layerState.stagingArea = {
|
state.layerState.stagingArea = {
|
||||||
boundingBox,
|
boundingBox,
|
||||||
sessionId,
|
|
||||||
images: [],
|
images: [],
|
||||||
selectedImageIndex: -1,
|
selectedImageIndex: -1,
|
||||||
};
|
};
|
||||||
@ -632,10 +642,7 @@ export const canvasSlice = createSlice({
|
|||||||
0
|
0
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
commitStagingAreaImage: (
|
commitStagingAreaImage: (state) => {
|
||||||
state,
|
|
||||||
_action: PayloadAction<string | undefined>
|
|
||||||
) => {
|
|
||||||
if (!state.layerState.stagingArea.images.length) {
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -869,9 +876,11 @@ export const {
|
|||||||
setScaledBoundingBoxDimensions,
|
setScaledBoundingBoxDimensions,
|
||||||
setShouldRestrictStrokesToBox,
|
setShouldRestrictStrokesToBox,
|
||||||
stagingAreaInitialized,
|
stagingAreaInitialized,
|
||||||
canvasSessionIdChanged,
|
|
||||||
setShouldAntialias,
|
setShouldAntialias,
|
||||||
canvasResized,
|
canvasResized,
|
||||||
|
canvasBatchIdAdded,
|
||||||
|
canvasSessionIdAdded,
|
||||||
|
canvasBatchesAndSessionsReset,
|
||||||
} = canvasSlice.actions;
|
} = canvasSlice.actions;
|
||||||
|
|
||||||
export default canvasSlice.reducer;
|
export default canvasSlice.reducer;
|
||||||
|
@ -89,7 +89,6 @@ export type CanvasLayerState = {
|
|||||||
stagingArea: {
|
stagingArea: {
|
||||||
images: CanvasImage[];
|
images: CanvasImage[];
|
||||||
selectedImageIndex: number;
|
selectedImageIndex: number;
|
||||||
sessionId?: string;
|
|
||||||
boundingBox?: IRect;
|
boundingBox?: IRect;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@ -166,6 +165,8 @@ export interface CanvasState {
|
|||||||
stageScale: number;
|
stageScale: number;
|
||||||
tool: CanvasTool;
|
tool: CanvasTool;
|
||||||
generationMode?: GenerationMode;
|
generationMode?: GenerationMode;
|
||||||
|
batchIds: string[];
|
||||||
|
sessionIds: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
|
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
|
||||||
|
@ -3,7 +3,7 @@ import { memo, useCallback } from 'react';
|
|||||||
import { ControlNetConfig } from '../store/controlNetSlice';
|
import { ControlNetConfig } from '../store/controlNetSlice';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { controlNetImageProcessed } from '../store/actions';
|
import { controlNetImageProcessed } from '../store/actions';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNet: ControlNetConfig;
|
controlNet: ControlNetConfig;
|
||||||
@ -12,7 +12,7 @@ type Props = {
|
|||||||
const ControlNetPreprocessButton = (props: Props) => {
|
const ControlNetPreprocessButton = (props: Props) => {
|
||||||
const { controlNetId, controlImage } = props.controlNet;
|
const { controlNetId, controlImage } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isReady = useIsReadyToEnqueue();
|
||||||
|
|
||||||
const handleProcess = useCallback(() => {
|
const handleProcess = useCallback(() => {
|
||||||
dispatch(
|
dispatch(
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import {
|
import {
|
||||||
ControlNetConfig,
|
ControlNetConfig,
|
||||||
controlNetAutoConfigToggled,
|
controlNetAutoConfigToggled,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
@ -15,7 +14,6 @@ type Props = {
|
|||||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||||
const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet;
|
const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||||
@ -28,7 +26,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
|
|||||||
aria-label={t('controlnet.autoConfigure')}
|
aria-label={t('controlnet.autoConfigure')}
|
||||||
isChecked={shouldAutoConfig}
|
isChecked={shouldAutoConfig}
|
||||||
onChange={handleShouldAutoConfigChanged}
|
onChange={handleShouldAutoConfigChanged}
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -11,11 +11,10 @@ import {
|
|||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type ParamControlNetModelProps = {
|
type ParamControlNetModelProps = {
|
||||||
controlNet: ControlNetConfig;
|
controlNet: ControlNetConfig;
|
||||||
@ -33,7 +32,6 @@ const selector = createSelector(
|
|||||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||||
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet;
|
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
|
||||||
|
|
||||||
const { mainModel } = useAppSelector(selector);
|
const { mainModel } = useAppSelector(selector);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -110,7 +108,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
|||||||
placeholder={t('controlnet.selectModel')}
|
placeholder={t('controlnet.selectModel')}
|
||||||
value={selectedModel?.id ?? null}
|
value={selectedModel?.id ?? null}
|
||||||
onChange={handleModelChanged}
|
onChange={handleModelChanged}
|
||||||
disabled={isBusy || !isEnabled}
|
disabled={!isEnabled}
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -6,7 +6,6 @@ import IAIMantineSearchableSelect, {
|
|||||||
IAISelectDataType,
|
IAISelectDataType,
|
||||||
} from 'common/components/IAIMantineSearchableSelect';
|
} from 'common/components/IAIMantineSearchableSelect';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||||
@ -56,7 +55,6 @@ const ParamControlNetProcessorSelect = (
|
|||||||
) => {
|
) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { controlNetId, isEnabled, processorNode } = props.controlNet;
|
const { controlNetId, isEnabled, processorNode } = props.controlNet;
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
|
||||||
const controlNetProcessors = useAppSelector(selector);
|
const controlNetProcessors = useAppSelector(selector);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -78,7 +76,7 @@ const ParamControlNetProcessorSelect = (
|
|||||||
value={processorNode.type ?? 'canny_image_processor'}
|
value={processorNode.type ?? 'canny_image_processor'}
|
||||||
data={controlNetProcessors}
|
data={controlNetProcessors}
|
||||||
onChange={handleProcessorTypeChanged}
|
onChange={handleProcessorTypeChanged}
|
||||||
disabled={isBusy || !isEnabled}
|
disabled={!isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
|
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
|
||||||
.default as RequiredCannyImageProcessorInvocation;
|
.default as RequiredCannyImageProcessorInvocation;
|
||||||
@ -20,7 +18,6 @@ type CannyProcessorProps = {
|
|||||||
const CannyProcessor = (props: CannyProcessorProps) => {
|
const CannyProcessor = (props: CannyProcessorProps) => {
|
||||||
const { controlNetId, processorNode, isEnabled } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { low_threshold, high_threshold } = processorNode;
|
const { low_threshold, high_threshold } = processorNode;
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -53,7 +50,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
|||||||
return (
|
return (
|
||||||
<ProcessorWrapper>
|
<ProcessorWrapper>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
label={t('controlnet.lowThreshold')}
|
label={t('controlnet.lowThreshold')}
|
||||||
value={low_threshold}
|
value={low_threshold}
|
||||||
onChange={handleLowThresholdChanged}
|
onChange={handleLowThresholdChanged}
|
||||||
@ -65,7 +62,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
|||||||
withSliderMarks
|
withSliderMarks
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
label={t('controlnet.highThreshold')}
|
label={t('controlnet.highThreshold')}
|
||||||
value={high_threshold}
|
value={high_threshold}
|
||||||
onChange={handleHighThresholdChanged}
|
onChange={handleHighThresholdChanged}
|
||||||
|
@ -2,11 +2,9 @@ import IAISlider from 'common/components/IAISlider';
|
|||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredContentShuffleImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredContentShuffleImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
|
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
|
||||||
.default as RequiredContentShuffleImageProcessorInvocation;
|
.default as RequiredContentShuffleImageProcessorInvocation;
|
||||||
@ -21,7 +19,6 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode, isEnabled } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
@ -101,7 +98,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label={t('controlnet.imageResolution')}
|
label={t('controlnet.imageResolution')}
|
||||||
@ -113,7 +110,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label={t('controlnet.w')}
|
label={t('controlnet.w')}
|
||||||
@ -125,7 +122,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label={t('controlnet.h')}
|
label={t('controlnet.h')}
|
||||||
@ -137,7 +134,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label={t('controlnet.f')}
|
label={t('controlnet.f')}
|
||||||
@ -149,7 +146,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={isBusy || !isEnabled}
|
isDisabled={!isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user