Merge remote-tracking branch 'origin/main' into feat/taesd

# Conflicts:
#	invokeai/backend/model_management/model_probe.py
This commit is contained in:
Kevin Turner
2023-09-20 10:46:55 -07:00
381 changed files with 14651 additions and 4930 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import sqlite3
from logging import Logger
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
@ -9,7 +10,10 @@ from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -25,6 +29,7 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_service import ModelManagerService
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.thread import lock
from .events import FastAPIEventService
@ -63,22 +68,32 @@ class ApiDependencies:
output_folder = config.output_path
# TODO: build a file/path manager?
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
if config.use_memory_db:
db_location = ":memory:"
else:
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
logger.info(f"Using database at {db_location}")
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
if config.log_sql:
db_conn.set_trace_callback(print)
db_conn.execute("PRAGMA foreign_keys = ON;")
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
conn=db_conn, table_name="graph_executions", lock=lock
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
boards = BoardService(
services=BoardServiceDependencies(
@ -120,18 +135,29 @@ class ApiDependencies:
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
session_processor=DefaultSessionProcessor(),
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
try:
lock.acquire()
db_conn.execute("VACUUM;")
db_conn.commit()
logger.info("Cleaned database")
finally:
lock.release()
@staticmethod
def shutdown():
if ApiDependencies.invoker:

View File

@ -103,3 +103,13 @@ async def set_log_level(
"""Sets the log verbosity level"""
ApiDependencies.invoker.services.logger.setLevel(level)
return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.delete(
"/invocation_cache",
operation_id="clear_invocation_cache",
responses={200: {"description": "The operation was successful"}},
)
async def clear_invocation_cache() -> None:
"""Clears the invocation cache"""
ApiDependencies.invoker.services.invocation_cache.clear()

View File

@ -0,0 +1,247 @@
from typing import Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelByBatchIDsResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
PruneResult,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.models import CursorPaginatedResults
from ...services.graph import Graph
from ..dependencies import ApiDependencies
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
class SessionQueueAndProcessorStatus(BaseModel):
"""The overall status of session queue and processor"""
queue: SessionQueueStatus
processor: SessionProcessorStatus
@session_queue_router.post(
"/{queue_id}/enqueue_graph",
operation_id="enqueue_graph",
responses={
201: {"model": EnqueueGraphResult},
},
)
async def enqueue_graph(
queue_id: str = Path(description="The queue id to perform this operation on"),
graph: Graph = Body(description="The graph to enqueue"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
) -> EnqueueGraphResult:
"""Enqueues a graph for single execution."""
return ApiDependencies.invoker.services.session_queue.enqueue_graph(queue_id=queue_id, graph=graph, prepend=prepend)
@session_queue_router.post(
"/{queue_id}/enqueue_batch",
operation_id="enqueue_batch",
responses={
201: {"model": EnqueueBatchResult},
},
)
async def enqueue_batch(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch: Batch = Body(description="Batch to process"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
@session_queue_router.get(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
},
)
async def list_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
limit: int = Query(default=50, description="The number of items to fetch"),
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets all queue items (without graphs)"""
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
)
@session_queue_router.put(
"/{queue_id}/processor/resume",
operation_id="resume",
responses={200: {"model": SessionProcessorStatus}},
)
async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
"/{queue_id}/processor/pause",
operation_id="pause",
responses={200: {"model": SessionProcessorStatus}},
)
async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
"/{queue_id}/cancel_by_batch_ids",
operation_id="cancel_by_batch_ids",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_batch_ids(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",
responses={
200: {"model": ClearResult},
},
)
async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
@session_queue_router.put(
"/{queue_id}/prune",
operation_id="prune",
responses={
200: {"model": PruneResult},
},
)
async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
@session_queue_router.get(
"/{queue_id}/current",
operation_id="get_current_queue_item",
responses={
200: {"model": Optional[SessionQueueItem]},
},
)
async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
@session_queue_router.get(
"/{queue_id}/next",
operation_id="get_next_queue_item",
responses={
200: {"model": Optional[SessionQueueItem]},
},
)
async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
@session_queue_router.get(
"/{queue_id}/status",
operation_id="get_queue_status",
responses={
200: {"model": SessionQueueAndProcessorStatus},
},
)
async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
@session_queue_router.get(
"/{queue_id}/b/{batch_id}/status",
operation_id="get_batch_status",
responses={
200: {"model": BatchStatus},
},
)
async def get_batch_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
@session_queue_router.get(
"/{queue_id}/i/{item_id}",
operation_id="get_queue_item",
responses={
200: {"model": SessionQueueItem},
},
)
async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.put(
"/{queue_id}/i/{item_id}/cancel",
operation_id="cancel_queue_item",
responses={
200: {"model": SessionQueueItem},
},
)
async def cancel_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)

View File

@ -23,12 +23,14 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
200: {"model": GraphExecutionState},
400: {"description": "Invalid json"},
},
deprecated=True,
)
async def create_session(
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
return session
@ -36,6 +38,7 @@ async def create_session(
"/",
operation_id="list_sessions",
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
deprecated=True,
)
async def list_sessions(
page: int = Query(default=0, description="The page of results to get"),
@ -57,6 +60,7 @@ async def list_sessions(
200: {"model": GraphExecutionState},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def get_session(
session_id: str = Path(description="The id of the session to get"),
@ -77,6 +81,7 @@ async def get_session(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def add_node(
session_id: str = Path(description="The id of the session"),
@ -109,6 +114,7 @@ async def add_node(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def update_node(
session_id: str = Path(description="The id of the session"),
@ -142,6 +148,7 @@ async def update_node(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def delete_node(
session_id: str = Path(description="The id of the session"),
@ -172,6 +179,7 @@ async def delete_node(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def add_edge(
session_id: str = Path(description="The id of the session"),
@ -203,6 +211,7 @@ async def add_edge(
400: {"description": "Invalid node or link"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def delete_edge(
session_id: str = Path(description="The id of the session"),
@ -241,8 +250,10 @@ async def delete_edge(
400: {"description": "The session has no invocations ready to invoke"},
404: {"description": "Session not found"},
},
deprecated=True,
)
async def invoke_session(
queue_id: str = Query(description="The id of the queue to associate the session with"),
session_id: str = Path(description="The id of the session to invoke"),
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
) -> Response:
@ -254,7 +265,7 @@ async def invoke_session(
if session.is_complete():
raise HTTPException(status_code=400)
ApiDependencies.invoker.invoke(session, invoke_all=all)
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
return Response(status_code=202)
@ -262,6 +273,7 @@ async def invoke_session(
"/{session_id}/invoke",
operation_id="cancel_session_invoke",
responses={202: {"description": "The invocation is canceled"}},
deprecated=True,
)
async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"),

View File

@ -0,0 +1,41 @@
from typing import Optional
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from fastapi import Body
from fastapi.routing import APIRouter
from pydantic import BaseModel
from pyparsing import ParseException
utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"])
class DynamicPromptsResponse(BaseModel):
prompts: list[str]
error: Optional[str] = None
@utilities_router.post(
"/dynamicprompts",
operation_id="parse_dynamicprompts",
responses={
200: {"model": DynamicPromptsResponse},
},
)
async def parse_dynamicprompts(
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
max_prompts: int = Body(default=1000, description="The max number of prompts to generate"),
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
) -> DynamicPromptsResponse:
"""Creates a batch process"""
try:
error: Optional[str] = None
if combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(prompt, max_prompts=max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(prompt, num_images=max_prompts)
except ParseException as e:
prompts = [prompt]
error = str(e)
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)

View File

@ -13,24 +13,22 @@ class SocketIO:
def __init__(self, app: FastAPI):
self.__sio = SocketManager(app=app)
self.__sio.on("subscribe", handler=self._handle_sub)
self.__sio.on("unsubscribe", handler=self._handle_unsub)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
async def _handle_session_event(self, event: Event):
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["graph_execution_state_id"],
room=event[1]["data"]["queue_id"],
)
async def _handle_sub(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.enter_room(sid, data["session"])
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
self.__sio.enter_room(sid, data["queue_id"])
# @app.sio.on('unsubscribe')
async def _handle_unsub(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.leave_room(sid, data["session"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
self.__sio.enter_room(sid, data["queue_id"])

View File

@ -1,4 +1,3 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
@ -33,7 +32,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, sessions
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
@ -92,6 +91,8 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
@ -102,6 +103,8 @@ app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?

View File

@ -1,4 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from .services.config import InvokeAIAppConfig
@ -12,6 +14,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import argparse
import re
import shlex
import sqlite3
import sys
import time
from typing import Optional, Union, get_type_hints
@ -249,19 +252,18 @@ def invoke_cli():
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
boards = BoardService(
services=BoardServiceDependencies(
@ -303,12 +305,13 @@ def invoke_cli():
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
configuration=config,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
system_graphs = create_system_graphs(services.graph_library)

View File

@ -67,6 +67,7 @@ class FieldDescriptions:
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
@ -155,6 +156,7 @@ class UIType(str, Enum):
VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
@ -417,12 +419,27 @@ class UIConfigBase(BaseModel):
class InvocationContext:
"""Initialized and provided to on execution of invocations."""
services: InvocationServices
graph_execution_state_id: str
queue_id: str
queue_item_id: int
queue_batch_id: str
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
def __init__(
self,
services: InvocationServices,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id
self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id
class BaseInvocationOutput(BaseModel):
@ -520,6 +537,9 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
class Config:
validate_assignment = True
validate_all = True
@staticmethod
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
uiconfig = getattr(model_class, "UIConfig", None)
@ -568,7 +588,29 @@ class BaseInvocation(ABC, BaseModel):
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
elif _input == Input.Any:
raise MissingInputException(self.__fields__["type"].default, field_name)
return self.invoke(context)
# skip node cache codepath if it's disabled
if context.services.configuration.node_cache_size == 0:
return self.invoke(context)
output: BaseInvocationOutput
if self.use_cache:
key = context.services.invocation_cache.create_key(self)
cached_value = context.services.invocation_cache.get(key)
if cached_value is None:
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
output = self.invoke(context)
context.services.invocation_cache.save(key, output)
return output
else:
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
return cached_value
else:
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
return self.invoke(context)
def get_type(self) -> str:
return self.__fields__["type"].default
id: str = Field(
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
@ -581,6 +623,7 @@ class BaseInvocation(ABC, BaseModel):
description="The workflow to save with the image",
ui_type=UIType.WorkflowField,
)
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
@validator("workflow", pre=True)
def validate_workflow_is_json(cls, v):
@ -604,6 +647,7 @@ def invocation(
tags: Optional[list[str]] = None,
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
"""
Adds metadata to an invocation.
@ -636,6 +680,8 @@ def invocation(
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
if use_cache is not None:
cls.__fields__["use_cache"].default = use_cache
# Add the invocation type to the pydantic model of the invocation
invocation_type_annotation = Literal[invocation_type] # type: ignore

View File

@ -56,6 +56,7 @@ class RangeOfSizeInvocation(BaseInvocation):
tags=["range", "integer", "random", "collection"],
category="collections",
version="1.0.0",
use_cache=False,
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@ -7,14 +7,14 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
@ -99,14 +99,15 @@ class CompelInvocation(BaseInvocation):
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, self.clip.skipped_layers
), text_encoder_info as text_encoder:
with (
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@ -122,7 +123,7 @@ class CompelInvocation(BaseInvocation):
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@ -213,14 +214,15 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
with (
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@ -244,7 +246,7 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@ -436,9 +438,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
text_fragments = [
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
(
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
)
for x in parsed_prompt.children
]
text = " ".join(text_fragments)

View File

@ -965,3 +965,42 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
width=image_dto.width,
height=image_dto.height,
)
@invocation(
"save_image",
title="Save Image",
tags=["primitives", "image"],
category="primitives",
version="1.0.0",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description="The image to load")
metadata: CoreMetadata = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -0,0 +1,105 @@
import os
from builtins import float
from typing import List, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
class CLIPVisionModelField(BaseModel):
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
# weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
# Outputs
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
)
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField(
default=1, ge=0, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.model_info(
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
)
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
# is currently messy due to differences between how the model info is generated when installing a model from
# disk vs. downloading the model.
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
)
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = CLIPVisionModelField(
model_name=image_encoder_model_name,
base_model=BaseModelType.Any,
)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=image_encoder_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)

View File

@ -10,6 +10,7 @@ import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@ -21,6 +22,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
@ -33,15 +35,17 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor,
)
@ -70,7 +74,6 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@ -193,7 +196,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.0.0",
version="1.1.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@ -221,9 +224,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[IPAdapterField] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
)
@validator("cfg_scale")
@ -325,8 +331,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: InvocationContext,
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int],
exit_stack: ExitStack,
@ -346,57 +350,107 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
control_list = None
if control_list is None:
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_data = []
control_models = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
return None
# After above handling, any control that is not None should now be of type list[ControlField].
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
control_item = ControlNetData(
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
conditioning_data: ConditioningData,
unet: UNet2DConditionModel,
exit_stack: ExitStack,
) -> Optional[IPAdapterData]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None
image_encoder_model_info = context.services.model_manager.get_model(
model_name=ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=ip_adapter.image_encoder_model.base_model,
context=context,
)
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
return IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=ip_adapter.weight,
begin_step_percent=ip_adapter.begin_step_percent,
end_step_percent=ip_adapter.end_step_percent,
)
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@ -490,9 +544,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
with (
ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
unet_info as unet,
):
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
@ -511,8 +568,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
control_data = self.prep_control_data(
model=pipeline,
controlnet_data = self.prep_control_data(
context=context,
control_input=self.control,
latents_shape=latents.shape,
@ -521,6 +577,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
exit_stack=exit_stack,
)
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
unet=unet,
exit_stack=exit_stack,
)
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
scheduler,
device=unet.device,
@ -539,7 +603,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
control_data=controlnet_data, # list[ControlNetData],
ip_adapter_data=ip_adapter_data, # IPAdapterData,
callback=step_callback,
)

View File

@ -54,7 +54,14 @@ class DivideInvocation(BaseInvocation):
return IntegerOutput(value=int(self.a / self.b))
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
@invocation(
"rand_int",
title="Random Integer",
tags=["math", "random"],
category="math",
version="1.0.0",
use_cache=False,
)
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""

View File

@ -95,9 +95,10 @@ class ONNXPromptInvocation(BaseInvocation):
print(f'Warn: trigger: "{trigger}" not found')
if loras or ti_list:
text_encoder.release_session()
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
orig_tokenizer, text_encoder, ti_list
) as (tokenizer, ti_manager):
with (
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
):
text_encoder.create_session()
# copy from

View File

@ -10,7 +10,14 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
@invocation(
"dynamic_prompt",
title="Dynamic Prompt",
tags=["prompt", "collection"],
category="prompt",
version="1.0.0",
use_cache=False,
)
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

View File

@ -53,24 +53,20 @@ class BoardImageRecordStorageBase(ABC):
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
self._lock = lock
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:

View File

@ -1,6 +1,5 @@
import sqlite3
import threading
import uuid
from abc import ABC, abstractmethod
from typing import Optional, Union, cast
@ -8,6 +7,7 @@ from pydantic import BaseModel, Extra, Field
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
from invokeai.app.util.misc import uuid_string
class BoardChanges(BaseModel, extra=Extra.forbid):
@ -87,24 +87,20 @@ class BoardRecordStorageBase(ABC):
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
self._lock = lock
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
@ -174,7 +170,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_name: str,
) -> BoardRecord:
try:
board_id = str(uuid.uuid4())
board_id = uuid_string()
self._lock.acquire()
self._cursor.execute(
"""--sql

View File

@ -16,7 +16,7 @@ import pydoc
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Union, get_args, get_origin, get_type_hints
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic import BaseSettings
@ -39,10 +39,10 @@ class InvokeAISettings(BaseSettings):
read from an omegaconf .yaml file.
"""
initconf: ClassVar[DictConfig] = None
initconf: ClassVar[Optional[DictConfig]] = None
argparse_groups: ClassVar[Dict] = {}
def parse_args(self, argv: list = sys.argv[1:]):
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0:
@ -83,7 +83,8 @@ class InvokeAISettings(BaseSettings):
else:
settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
env_prefix = getattr(cls.Config, "env_prefix", None)
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
initconf = (
cls.initconf.get(settings_stanza)
@ -116,8 +117,8 @@ class InvokeAISettings(BaseSettings):
field.default = current_default
@classmethod
def cmd_name(self, command_field: str = "type") -> str:
hints = get_type_hints(self)
def cmd_name(cls, command_field: str = "type") -> str:
hints = get_type_hints(cls)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
@ -133,16 +134,12 @@ class InvokeAISettings(BaseSettings):
return parser
@classmethod
def add_subparser(cls, parser: argparse.ArgumentParser):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self) -> List[str]:
def _excluded(cls) -> List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf"]
@classmethod
def _excluded_from_yaml(self) -> List[str]:
def _excluded_from_yaml(cls) -> List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return [
"type",

View File

@ -194,8 +194,8 @@ class InvokeAIAppConfig(InvokeAISettings):
setting environment variables INVOKEAI_<setting>.
"""
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict]] = None
# fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
@ -234,6 +234,7 @@ class InvokeAIAppConfig(InvokeAISettings):
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
@ -245,18 +246,23 @@ class InvokeAIAppConfig(InvokeAISettings):
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
# DEVICE
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
# GENERATION
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
@ -272,7 +278,7 @@ class InvokeAIAppConfig(InvokeAISettings):
class Config:
validate_assignment = True
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
"""
Update settings with contents of init file, environment, and
command-line settings.
@ -283,12 +289,16 @@ class InvokeAIAppConfig(InvokeAISettings):
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
super().parse_args(argv)
loaded_conf = None
if conf is None:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
loaded_conf = OmegaConf.load(self.root_dir / INIT_FILE)
except Exception:
pass
InvokeAISettings.initconf = conf
if isinstance(loaded_conf, DictConfig):
InvokeAISettings.initconf = loaded_conf
else:
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
super().parse_args(argv)
@ -376,13 +386,6 @@ class InvokeAIAppConfig(InvokeAISettings):
"""
return self._resolve(self.models_dir)
@property
def autoconvert_path(self) -> Path:
"""
Path to the directory containing models to be imported automatically at startup.
"""
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self) -> bool:
@ -405,11 +408,11 @@ class InvokeAIAppConfig(InvokeAISettings):
return True
@property
def ram_cache_size(self) -> float:
def ram_cache_size(self) -> Union[Literal["auto"], float]:
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> float:
def vram_cache_size(self) -> Union[Literal["auto"], float]:
return self.max_vram_cache_size or self.vram
@property

View File

@ -10,57 +10,58 @@ default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
def create_text_to_image() -> LibraryGraph:
graph = Graph(
nodes={
"width": IntegerInvocation(id="width", value=512),
"height": IntegerInvocation(id="height", value=512),
"seed": IntegerInvocation(id="seed", value=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
"6": DenoiseLatentsInvocation(id="6"),
"7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="value"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="value"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="value"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
)
return LibraryGraph(
id=default_text_to_image_graph_id,
name="t2i",
description="Converts text to an image",
graph=Graph(
nodes={
"width": IntegerInvocation(id="width", value=512),
"height": IntegerInvocation(id="height", value=512),
"seed": IntegerInvocation(id="seed", value=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
"6": DenoiseLatentsInvocation(id="6"),
"7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="value"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="value"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="value"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
),
graph=graph,
exposed_inputs=[
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),

View File

@ -4,21 +4,23 @@ from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
class EventServiceBase:
session_event: str = "session_event"
queue_event: str = "queue_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self, event_name: str, payload: dict) -> None:
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.session_event,
event_name=EventServiceBase.queue_event,
payload=dict(event=event_name, data=payload),
)
@ -26,6 +28,9 @@ class EventServiceBase:
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
@ -35,11 +40,14 @@ class EventServiceBase:
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="generator_progress",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
node_id=node.get("id"),
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
step=step,
@ -50,15 +58,21 @@ class EventServiceBase:
def emit_invocation_complete(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
@ -68,6 +82,9 @@ class EventServiceBase:
def emit_invocation_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
@ -75,9 +92,12 @@ class EventServiceBase:
error: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
@ -86,28 +106,47 @@ class EventServiceBase:
),
)
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
def emit_invocation_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
),
)
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
@ -115,9 +154,12 @@ class EventServiceBase:
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="model_load_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@ -128,6 +170,9 @@ class EventServiceBase:
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
@ -136,9 +181,12 @@ class EventServiceBase:
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="model_load_completed",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@ -152,14 +200,20 @@ class EventServiceBase:
def emit_session_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="session_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
@ -168,18 +222,78 @@ class EventServiceBase:
def emit_invocation_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
queue_id=session_queue_item.queue_id,
queue_item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload=dict(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
),
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload=dict(queue_id=queue_id),
)

View File

@ -2,13 +2,14 @@
import copy
import itertools
import uuid
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from invokeai.app.util.misc import uuid_string
# Importing * is bad karma but needed here for node detection
from ..invocations import * # noqa: F401 F403
from ..invocations.baseinvocation import (
@ -137,19 +138,31 @@ def are_connections_compatible(
return are_connection_types_compatible(from_node_field, to_node_field)
class NodeAlreadyInGraphError(Exception):
class NodeAlreadyInGraphError(ValueError):
pass
class InvalidEdgeError(Exception):
class InvalidEdgeError(ValueError):
pass
class NodeNotFoundError(Exception):
class NodeNotFoundError(ValueError):
pass
class NodeAlreadyExecutedError(Exception):
class NodeAlreadyExecutedError(ValueError):
pass
class DuplicateNodeIdError(ValueError):
pass
class NodeFieldNotFoundError(ValueError):
pass
class NodeIdMismatchError(ValueError):
pass
@ -227,7 +240,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
@ -237,6 +250,59 @@ class Graph(BaseModel):
default_factory=list,
)
@root_validator
def validate_nodes_and_edges(cls, values):
"""Validates that all edges match nodes in the graph"""
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
edges = cast(Optional[list[Edge]], values.get("edges"))
if nodes is not None:
# Validate that all node ids are unique
node_ids = [n.id for n in nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
# Validate that all node ids match the keys in the nodes dict
for k, v in nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
if edges is not None and nodes is not None:
# Validate that all edges match nodes in the graph
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
if missing_node_ids:
raise NodeNotFoundError(
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
)
# Validate that all edge fields match node fields in the graph
for edge in edges:
source_node = nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
destination_node = nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeFieldNotFoundError(
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
)
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
return values
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@ -697,8 +763,7 @@ class Graph(BaseModel):
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")
@ -847,7 +912,7 @@ class GraphExecutionState(BaseModel):
new_node = copy.deepcopy(node)
# Create the node id (use a random uuid)
new_node.id = str(uuid.uuid4())
new_node.id = uuid_string()
# Set the iteration index for iteration invocations
if isinstance(new_node, IterateInvocation):
@ -1082,7 +1147,7 @@ class ExposedNodeOutput(BaseModel):
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")

View File

@ -148,24 +148,20 @@ class ImageRecordStorageBase(ABC):
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
self._lock = lock
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional
from PIL.Image import Image as PILImageType
@ -38,6 +38,29 @@ if TYPE_CHECKING:
class ImageServiceABC(ABC):
"""High-level service for image management."""
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an image is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: ImageDTO) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
@abstractmethod
def create(
self,
@ -161,6 +184,7 @@ class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(self, services: ImageServiceDependencies):
super().__init__()
self._services = services
def create(
@ -217,6 +241,7 @@ class ImageService(ImageServiceABC):
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
@ -235,7 +260,9 @@ class ImageService(ImageServiceABC):
) -> ImageDTO:
try:
self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
image_dto = self.get_dto(image_name)
self._on_changed(image_dto)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@ -374,6 +401,7 @@ class ImageService(ImageServiceABC):
try:
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image record")
raise
@ -390,6 +418,8 @@ class ImageService(ImageServiceABC):
for image_name in image_names:
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)
for image_name in image_names:
self._on_deleted(image_name)
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")
raise
@ -406,6 +436,7 @@ class ImageService(ImageServiceABC):
count = len(image_names)
for image_name in image_names:
self._services.image_files.delete(image_name)
self._on_deleted(image_name)
return count
except ImageRecordDeleteException:
self._services.logger.error("Failed to delete image records")

View File

@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
class InvocationCacheBase(ABC):
"""
Base class for invocation caches.
When an invocation is executed, it is hashed and its output stored in the cache.
When new invocations are executed, if they are flagged with `use_cache`, they
will attempt to pull their value from the cache before executing.
Implementations should register for the `on_deleted` event of the `images` and `latents`
services, and delete any cached outputs that reference the deleted image or latent.
See the memory implementation for an example.
Implementations should respect the `node_cache_size` configuration value, and skip all
cache logic if the value is set to 0.
"""
@abstractmethod
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
"""Retrieves an invocation output from the cache"""
pass
@abstractmethod
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
"""Stores an invocation output in the cache"""
pass
@abstractmethod
def delete(self, key: Union[int, str]) -> None:
"""Deleteds an invocation output from the cache"""
pass
@abstractmethod
def clear(self) -> None:
"""Clears the cache"""
pass
@abstractmethod
def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass

View File

@ -0,0 +1,81 @@
from queue import Queue
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invoker import Invoker
class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
__max_cache_size: int
__cache_ids: Queue
__invoker: Invoker
def __init__(self, max_cache_size: int = 0) -> None:
self.__cache = dict()
self.__max_cache_size = max_cache_size
self.__cache_ids = Queue()
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
if self.__max_cache_size == 0:
return
self.__invoker.services.images.on_deleted(self._delete_by_match)
self.__invoker.services.latents.on_deleted(self._delete_by_match)
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
if self.__max_cache_size == 0:
return
item = self.__cache.get(key, None)
if item is not None:
return item[0]
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
if self.__max_cache_size == 0:
return
if key not in self.__cache:
self.__cache[key] = (invocation_output, invocation_output.json())
self.__cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size:
try:
self.__cache.pop(self.__cache_ids.get())
except KeyError:
# this means the cache_ids are somehow out of sync w/ the cache
pass
def delete(self, key: Union[int, str]) -> None:
if self.__max_cache_size == 0:
return
if key in self.__cache:
del self.__cache[key]
def clear(self, *args, **kwargs) -> None:
if self.__max_cache_size == 0:
return
self.__cache.clear()
self.__cache_ids = Queue()
def create_key(self, invocation: BaseInvocation) -> int:
return hash(invocation.json(exclude={"id"}))
def _delete_by_match(self, to_match: str) -> None:
if self.__max_cache_size == 0:
return
keys_to_delete = set()
for key, value_tuple in self.__cache.items():
if to_match in value_tuple[1]:
keys_to_delete.add(key)
if not keys_to_delete:
return
for key in keys_to_delete:
self.delete(key)
self.__invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}")

View File

@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
session_queue_item_id: int = Field(
description="The ID of session queue item from which this invocation queue item came"
)
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@ -12,12 +12,15 @@ if TYPE_CHECKING:
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
class InvocationServices:
@ -28,8 +31,8 @@ class InvocationServices:
boards: "BoardServiceABC"
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC"
latents: "LatentsStorageBase"
logger: "Logger"
@ -37,6 +40,9 @@ class InvocationServices:
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
session_queue: "SessionQueueBase"
session_processor: "SessionProcessorBase"
invocation_cache: "InvocationCacheBase"
def __init__(
self,
@ -44,8 +50,8 @@ class InvocationServices:
boards: "BoardServiceABC",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
graph_library: "ItemStorageABC[LibraryGraph]",
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
@ -53,10 +59,12 @@ class InvocationServices:
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
):
self.board_images = board_images
self.boards = boards
self.boards = boards
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
@ -68,3 +76,6 @@ class InvocationServices:
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache

View File

@ -17,7 +17,14 @@ class Invoker:
self.services = services
self._start()
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
def invoke(
self,
session_queue_id: str,
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
invoke_all: bool = False,
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@ -32,7 +39,9 @@ class Invoker:
# Queue the invocation
self.services.queue.put(
InvocationQueueItem(
# session_id = session.id,
session_queue_id=session_queue_id,
session_queue_item_id=session_queue_item_id,
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invoke_all=invoke_all,

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional, Union
from typing import Callable, Dict, Optional, Union
import torch
@ -11,6 +11,13 @@ import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@ -23,6 +30,22 @@ class LatentsStorageBase(ABC):
def delete(self, name: str) -> None:
pass
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: torch.Tensor) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
self._on_changed(data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
self._on_deleted(name)
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
return None if name not in self.__cache else self.__cache[name]

View File

@ -525,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event(
self,
context,
context: InvocationContext,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
@ -537,6 +537,9 @@ class ModelManagerService(ModelManagerServiceBase):
if model_info:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@ -546,6 +549,9 @@ class ModelManagerService(ModelManagerServiceBase):
)
else:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,

View File

@ -1,6 +1,7 @@
import time
import traceback
from threading import BoundedSemaphore, Event, Thread
from typing import Optional
import invokeai.backend.util.logging as logger
@ -37,10 +38,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__threadLimit.acquire()
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
queue_item: Optional[InvocationQueueItem] = None
while not stop_event.is_set():
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
queue_item = self.__invoker.services.queue.get()
except Exception as e:
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
@ -48,7 +50,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# do not hammer the queue
time.sleep(0.5)
continue
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
@ -56,6 +57,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
@ -67,6 +71,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
@ -79,6 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send starting event
self.__invoker.services.events.emit_invocation_started(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@ -89,13 +99,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accomodates nodes which require a value, but get it only from a
# connection
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
# - referencing the invocation cache instead of executing the node
outputs = invocation.invoke_internal(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id,
)
)
@ -111,6 +125,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@ -138,6 +155,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@ -155,10 +175,19 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
self.__invoker.invoke(
session_queue_batch_id=queue_item.session_queue_batch_id,
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
invoke_all=True,
)
except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@ -166,7 +195,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor

View File

@ -1,7 +1,8 @@
import uuid
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
from invokeai.app.util.misc import uuid_string
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
@ -25,6 +26,6 @@ class SimpleNameService(NameServiceBase):
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = str(uuid.uuid4())
uuid_str = uuid_string()
filename = f"{uuid_str}.png"
return filename

View File

@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
class SessionProcessorBase(ABC):
"""
Base class for session processor.
The session processor is responsible for executing sessions. It runs a simple polling loop,
checking the session queue for new sessions to execute. It must coordinate with the
invocation queue to ensure only one session is executing at a time.
"""
@abstractmethod
def resume(self) -> SessionProcessorStatus:
"""Starts or resumes the session processor"""
pass
@abstractmethod
def pause(self) -> SessionProcessorStatus:
"""Pauses the session processor"""
pass
@abstractmethod
def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor"""
pass

View File

@ -0,0 +1,6 @@
from pydantic import BaseModel, Field
class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
is_processing: bool = Field(description="Whether a session is being processed")

View File

@ -0,0 +1,124 @@
from threading import BoundedSemaphore
from threading import Event as ThreadEvent
from threading import Thread
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus
POLLING_INTERVAL = 1
THREAD_LIMIT = 1
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker) -> None:
self.__invoker: Invoker = invoker
self.__queue_item: Optional[SessionQueueItem] = None
self.__resume_event = ThreadEvent()
self.__stop_event = ThreadEvent()
self.__poll_now_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
self.__thread = Thread(
name="session_processor",
target=self.__process,
kwargs=dict(
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
),
)
self.__thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def _poll_now(self) -> None:
self.__poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
match event_name:
case "graph_execution_state_complete" | "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
self.__queue_item = None
self._poll_now()
case "session_canceled" if self.__queue_item is not None and self.__queue_item.session_id == event[1][
"data"
]["graph_execution_state_id"]:
self.__queue_item = None
self._poll_now()
case "batch_enqueued":
self._poll_now()
case "queue_cleared":
self.__queue_item = None
self._poll_now()
def resume(self) -> SessionProcessorStatus:
if not self.__resume_event.is_set():
self.__resume_event.set()
return self.get_status()
def pause(self) -> SessionProcessorStatus:
if self.__resume_event.is_set():
self.__resume_event.clear()
return self.get_status()
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self.__resume_event.is_set(),
is_processing=self.__queue_item is not None,
)
def __process(
self,
stop_event: ThreadEvent,
poll_now_event: ThreadEvent,
resume_event: ThreadEvent,
):
try:
stop_event.clear()
resume_event.set()
self.__threadLimit.acquire()
queue_item: Optional[SessionQueueItem] = None
self.__invoker.services.logger
while not stop_event.is_set():
poll_now_event.clear()
# do not dequeue if there is already a session running
if self.__queue_item is None and resume_event.is_set():
queue_item = self.__invoker.services.session_queue.dequeue()
if queue_item is not None:
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(
session_queue_batch_id=queue_item.batch_id,
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
invoke_all=True,
)
queue_item = None
if queue_item is None:
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
pass
finally:
stop_event.clear()
poll_now_event.clear()
self.__queue_item = None
self.__threadLimit.release()

View File

@ -0,0 +1,112 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.graph import Graph
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
IsEmptyResult,
IsFullResult,
PruneResult,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.models import CursorPaginatedResults
class SessionQueueBase(ABC):
"""Base class for session queue"""
@abstractmethod
def dequeue(self) -> Optional[SessionQueueItem]:
"""Dequeues the next session queue item."""
pass
@abstractmethod
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
"""Enqueues a single graph for execution."""
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
"""Enqueues all permutations of a batch for execution."""
pass
@abstractmethod
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
"""Gets the currently-executing session queue item"""
pass
@abstractmethod
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
"""Gets the next session queue item (does not dequeue it)"""
pass
@abstractmethod
def clear(self, queue_id: str) -> ClearResult:
"""Deletes all session queue items"""
pass
@abstractmethod
def prune(self, queue_id: str) -> PruneResult:
"""Deletes all completed and errored session queue items"""
pass
@abstractmethod
def is_empty(self, queue_id: str) -> IsEmptyResult:
"""Checks if the queue is empty"""
pass
@abstractmethod
def is_full(self, queue_id: str) -> IsFullResult:
"""Checks if the queue is empty"""
pass
@abstractmethod
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
"""Gets the status of the queue"""
pass
@abstractmethod
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
"""Gets the status of a batch"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item"""
pass
@abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""
pass
@abstractmethod
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""
pass

View File

@ -0,0 +1,418 @@
import datetime
import json
from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator, validator
from pydantic.json import pydantic_encoder
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.util.misc import uuid_string
# region Errors
class BatchZippedLengthError(ValueError):
"""Raise when a batch has items of different lengths."""
class BatchItemsTypeError(TypeError):
"""Raise when a batch has items of different types."""
class BatchDuplicateNodeFieldError(ValueError):
"""Raise when a batch has duplicate node_path and field_name."""
class TooManySessionsError(ValueError):
"""Raise when too many sessions are requested."""
class SessionQueueItemNotFoundError(ValueError):
"""Raise when a queue item is not found."""
# endregion
# region Batch
BatchDataType = Union[
StrictStr,
float,
int,
]
class NodeFieldValue(BaseModel):
node_path: str = Field(description="The node into which this batch data item will be substituted.")
field_name: str = Field(description="The field into which this batch data item will be substituted.")
value: BatchDataType = Field(description="The value to substitute into the node/field.")
class BatchDatum(BaseModel):
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
items: list[BatchDataType] = Field(
default_factory=list, description="The list of items to substitute into the node/field."
)
BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
runs: int = Field(
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
)
@validator("data")
def validate_lengths(cls, v: Optional[BatchDataCollection]):
if v is None:
return v
for batch_data_list in v:
first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
for i in batch_data_list:
if len(i.items) != first_item_length:
raise BatchZippedLengthError("Zipped batch items must all have the same length")
return v
@validator("data")
def validate_types(cls, v: Optional[BatchDataCollection]):
if v is None:
return v
for batch_data_list in v:
for datum in batch_data_list:
# Get the type of the first item in the list
first_item_type = type(datum.items[0]) if datum.items else None
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")
return v
@validator("data")
def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
if v is None:
return v
paths: set[tuple[str, str]] = set()
for batch_data_list in v:
for datum in batch_data_list:
pair = (datum.node_path, datum.field_name)
if pair in paths:
raise BatchDuplicateNodeFieldError("Each batch data must have unique node_id and field_name")
paths.add(pair)
return v
@root_validator(skip_on_failure=True)
def validate_batch_nodes_and_edges(cls, values):
batch_data_collection = cast(Optional[BatchDataCollection], values["data"])
if batch_data_collection is None:
return values
graph = cast(Graph, values["graph"])
for batch_data_list in batch_data_collection:
for batch_data in batch_data_list:
try:
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
except NodeNotFoundError:
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
if batch_data.field_name not in node.__fields__:
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
return values
class Config:
schema_extra = {
"required": [
"graph",
"runs",
]
}
# endregion Batch
# region Queue Items
DEFAULT_QUEUE_ID = "default"
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
field_values_raw = queue_item_dict.get("field_values", None)
return parse_raw_as(list[NodeFieldValue], field_values_raw) if field_values_raw is not None else None
def get_session(queue_item_dict: dict) -> GraphExecutionState:
session_raw = queue_item_dict.get("session", "{}")
return parse_raw_as(GraphExecutionState, session_raw)
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
item_id: int = Field(description="The identifier of the session queue item")
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item")
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
field_values: Optional[list[NodeFieldValue]] = Field(
default=None, description="The field values that were used for this queue item"
)
queue_id: str = Field(description="The id of the queue with which this item is associated")
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
@classmethod
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
return SessionQueueItemDTO(**queue_item_dict)
class Config:
schema_extra = {
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
@classmethod
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
queue_item_dict["session"] = get_session(queue_item_dict)
return SessionQueueItem(**queue_item_dict)
class Config:
schema_extra = {
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"session",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
# endregion Queue Items
# region Query Results
class SessionQueueStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
item_id: Optional[int] = Field(description="The current queue item id")
batch_id: Optional[str] = Field(description="The current queue item's batch id")
session_id: Optional[str] = Field(description="The current queue item's session id")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
class EnqueueBatchResult(BaseModel):
queue_id: str = Field(description="The ID of the queue")
enqueued: int = Field(description="The total number of queue items enqueued")
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
class EnqueueGraphResult(BaseModel):
enqueued: int = Field(description="The total number of queue items enqueued")
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
queue_item: SessionQueueItemDTO = Field(description="The queue item that was enqueued")
class ClearResult(BaseModel):
"""Result of clearing the session queue"""
deleted: int = Field(..., description="Number of queue items deleted")
class PruneResult(ClearResult):
"""Result of pruning the session queue"""
pass
class CancelByBatchIDsResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""
pass
class IsEmptyResult(BaseModel):
"""Result of checking if the session queue is empty"""
is_empty: bool = Field(..., description="Whether the session queue is empty")
class IsFullResult(BaseModel):
"""Result of checking if the session queue is full"""
is_full: bool = Field(..., description="Whether the session queue is full")
# endregion Query Results
# region Util
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
"""
Populates the given graph with the given batch data items.
"""
graph_clone = graph.copy(deep=True)
for item in node_field_values:
node = graph_clone.get_node(item.node_path)
if node is None:
continue
setattr(node, item.field_name, item.value)
graph_clone.update_node(item.node_path, node)
return graph_clone
def create_session_nfv_tuples(
batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
"""
Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
that was applied to the graph.
"""
# TODO: Should this be a class method on Batch?
data: list[list[tuple[NodeFieldValue]]] = []
batch_data_collection = batch.data if batch.data is not None else []
for batch_datum_list in batch_data_collection:
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
node_field_values_to_zip: list[list[NodeFieldValue]] = []
for batch_datum in batch_datum_list:
node_field_values = [
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
data.append(list(zip(*node_field_values_to_zip)))
# create generator to yield session,nfv tuples
count = 0
for _ in range(batch.runs):
for d in product(*data):
if count >= maximum:
return
flat_node_field_values = list(chain.from_iterable(d))
graph = populate_graph(batch.graph, flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values)
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring
the overhead of actually generating them. Adapted from `create_sessions().
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
return batch.runs
data = []
for batch_datum_list in batch.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip)))
data_product = list(product(*data))
return len(data_product) * batch.runs
class SessionQueueValueToInsert(NamedTuple):
"""A tuple of values to insert into the session_queue table"""
queue_id: str # queue_id
session: str # session json
session_id: str # session_id
batch_id: str # batch_id
field_values: Optional[str] # field_values json
priority: int # priority
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
values_to_insert: ValuesToInsert = []
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
# sessions must have unique id
session.id = uuid_string()
values_to_insert.append(
SessionQueueValueToInsert(
queue_id, # queue_id
session.json(), # session (json)
session.id, # session_id
batch.batch_id, # batch_id
# must use pydantic_encoder bc field_values is a list of models
json.dumps(field_values, default=pydantic_encoder) if field_values else None, # field_values (json)
priority, # priority
)
)
return values_to_insert
# endregion Util

View File

@ -0,0 +1,816 @@
import sqlite3
import threading
from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
DEFAULT_QUEUE_ID,
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
IsEmptyResult,
IsFullResult,
PruneResult,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.models import CursorPaginatedResults
class SqliteSessionQueue(SessionQueueBase):
__invoker: Invoker
__conn: sqlite3.Connection
__cursor: sqlite3.Cursor
__lock: threading.Lock
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
super().__init__()
self.__conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self.__conn.row_factory = sqlite3.Row
self.__cursor = self.__conn.cursor()
self.__lock = lock
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
match event_name:
case "graph_execution_state_complete":
await self._handle_complete_event(event)
case "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
await self._handle_error_event(event)
case "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
def _create_tables(self) -> None:
"""Creates the session queue tables, indicies, and triggers"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
WHERE status = 'in_progress';
"""
)
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, self.__cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
self.__cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
AND batch_id = ?
""",
(queue_id, enqueue_result.batch.batch_id),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
return EnqueueGraphResult(
**enqueue_result.dict(),
queue_item=SessionQueueItemDTO.from_dict(dict(result)),
)
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
self.__lock.acquire()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority)
VALUES (?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
enqueued=enqueued_count,
batch=batch,
priority=priority,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
queue_item = SessionQueueItem.from_dict(dict(result))
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
return None
return SessionQueueItem.from_dict(dict(result))
def _set_queue_item_status(
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error = ?
WHERE item_id = ?
""",
(status, error, item_id),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def is_empty(self, queue_id: str) -> IsEmptyResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return IsFullResult(is_full=is_full)
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id=item_id)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
DELETE FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return queue_item
def clear(self, queue_id: str) -> ClearResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
where = """--sql
WHERE
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
"""
self.__lock.acquire()
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
DELETE
FROM session_queue
{where};
""",
(queue_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
WHERE
queue_id == ?
AND batch_id IN ({placeholders})
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = [queue_id] + batch_ids
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
tuple(params),
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id is ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = [queue_id]
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
tuple(params),
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByQueueIDResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.from_dict(dict(result))
def list_queue_items(
self,
queue_id: str,
limit: int,
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
try:
item_id = cursor
self.__lock.acquire()
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
self.__cursor.execute(query, params)
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
item_id=current_item.item_id if current_item else None,
session_id=current_item.session_id if current_item else None,
batch_id=current_item.batch_id if current_item else None,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return BatchStatus(
batch_id=batch_id,
queue_id=queue_id,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
)

View File

@ -0,0 +1,14 @@
from typing import Generic, TypeVar
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
"""Cursor-paginated results"""
limit: int = Field(..., description="Limit of items to get")
has_more: bool = Field(..., description="Whether there are more items available")
items: list[GenericBaseModel] = Field(..., description="Items")

View File

@ -1,5 +1,5 @@
import sqlite3
from threading import Lock
import threading
from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as
@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: Lock
_lock: threading.Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
super().__init__()
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._lock = lock
self._conn = conn
self._cursor = self._conn.cursor()
self._create_table()
@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
parsed = parse_raw_as(item_type, item)
return parsed
return parse_raw_as(item_type, item)
def set(self, item: T):
try:

View File

@ -0,0 +1,3 @@
import threading
lock = threading.Lock()

View File

@ -1,4 +1,5 @@
import datetime
import uuid
import numpy as np
@ -21,3 +22,8 @@ SEED_MAX = np.iinfo(np.uint32).max
def get_random_seed():
rng = np.random.default_rng(seed=None)
return int(rng.integers(0, SEED_MAX))
def uuid_string():
res = uuid.uuid4()
return str(res)

View File

@ -110,6 +110,9 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,