Compare commits

...

87 Commits

Author SHA1 Message Date
732780c376 Playing with eventservice tests 2023-09-08 12:33:12 -04:00
ed7deee8f1 Merge branch 'main' into feat/batch-graphs 2023-09-06 11:03:54 -04:00
d22c4734ee feat(api): move batches to own router 2023-09-06 16:34:49 +10:00
3e7dadd7b3 fix(nodes): add version to iterate and collect 2023-09-05 23:46:20 +10:00
b777dba430 feat: batch events
When a batch creates a session, we need to alert the client of this. Because the sessions are created by the batch manager (not directly in response to a client action), we need to emit an event with the session id.

To accommodate this, a secondary set of sio sub/unsub/event handlers is created, specifically for batch events. The room is the `batch_id`.

When creating a batch, the client subscribes to this batch room.

When the batch manager creates a batch session, a `batch_session_created` event is emitted in the appropriate room. It includes the session id. The client may then subscribe to the session room, and all socket communication proceeds as before.
2023-09-05 21:17:33 +10:00
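The room-based subscribe flow described in this commit can be sketched with a minimal in-memory room registry. This is an illustrative stand-in, not the actual sio implementation; `RoomRegistry` and its method names are hypothetical.

```python
# Hypothetical sketch of the batch-room subscribe/unsubscribe flow.
from collections import defaultdict


class RoomRegistry:
    def __init__(self) -> None:
        # room name (here, a batch_id) -> set of subscribed client sids
        self._rooms: dict[str, set[str]] = defaultdict(set)

    def handle_sub_batch(self, sid: str, data: dict) -> None:
        if "batch_id" in data:
            self._rooms[data["batch_id"]].add(sid)

    def handle_unsub_batch(self, sid: str, data: dict) -> None:
        if "batch_id" in data:
            self._rooms[data["batch_id"]].discard(sid)

    def recipients(self, room: str) -> set[str]:
        return self._rooms[room]


registry = RoomRegistry()
registry.handle_sub_batch("client-1", {"batch_id": "batch-abc"})
# When the batch manager creates a session, the event targets the batch room:
# emit("batch_session_created",
#      {"batch_id": "batch-abc", "session_id": ...}, room="batch-abc")
```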
531c3bb1e2 fix(tests): fix batch tests [WIP] 2023-09-05 18:07:47 +10:00
331743ca0c Merge branch 'main' into feat/batch-graphs 2023-09-05 17:32:42 +10:00
13429e66b3 chore(ui): typegen 2023-09-05 17:32:05 +10:00
2185c85287 feat(tests): add test for batch with subgraph [WIP]
The tests still fail because the test events service does not emit events the batch manager can listen for.
2023-09-05 17:31:55 +10:00
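The gap described above — a test events service that records events but never dispatches them — could be closed with a fake that both records and forwards. The sketch below is hypothetical; names do not come from the repository.

```python
# Illustrative test events service: records emitted events for assertions
# and forwards them to registered listeners (e.g. a batch manager).
from typing import Any, Callable

Payload = dict[str, Any]


class RecordingEventService:
    def __init__(self) -> None:
        self.events: list[Payload] = []
        self._handlers: list[Callable[[Payload], None]] = []

    def register(self, handler: Callable[[Payload], None]) -> None:
        self._handlers.append(handler)

    def emit(self, event_name: str, data: Payload) -> None:
        payload = {"event": event_name, "data": data}
        self.events.append(payload)      # recorded, so tests can assert on it
        for handler in self._handlers:   # forwarded, so listeners actually run
            handler(payload)


svc = RecordingEventService()
seen: list[str] = []
svc.register(lambda e: seen.append(e["event"]))
svc.emit("graph_execution_state_complete", {"graph_execution_state_id": "s1"})
```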
e8a4a654ac feat(batch): use node_path instead of node_id to create batched sessions 2023-09-05 17:30:27 +10:00
26f9ac9f21 Revert "Revert "feat(batches): defer ges *and* batch session creation until execution time""
This reverts commit be971617e3.
2023-09-05 16:34:40 +10:00
8d78af5db7 fix(tests): fix name of test 2023-09-05 16:14:52 +10:00
babd26feab feat(batch): extract repeated logic to function 2023-09-05 16:14:18 +10:00
e9b26e5e7d fix(api): fix duplicate operation id 2023-09-05 16:06:35 +10:00
6b946f53c4 Merge branch 'main' into feat/batch-graphs 2023-09-05 16:02:56 +10:00
70479b9827 Merge branch 'main' into feat/batch-graphs 2023-08-29 10:39:45 -04:00
35099dcdd8 Fix operation id on an endpoint 2023-08-29 10:32:22 -04:00
670600a863 Merge branch 'main' into feat/batch-graphs 2023-08-29 10:22:14 -04:00
6d5403e19d Merge branch 'main' into feat/batch-graphs 2023-08-29 09:07:18 -04:00
0f7695a081 Merge branch 'main' into feat/batch-graphs 2023-08-27 11:41:45 -04:00
d567d9f804 Merge branch 'main' into feat/batch-graphs 2023-08-22 10:12:30 -04:00
68f6140685 fix(tests): fix batches test 2023-08-22 01:48:53 +10:00
be971617e3 Revert "feat(batches): defer ges *and* batch session creation until execution time"
This reverts commit 1652143671.
2023-08-22 01:23:39 +10:00
1652143671 feat(batches): defer ges *and* batch session creation until execution time 2023-08-22 00:54:17 +10:00
88ae19a768 feat(batches): defer ges creation until execution
This improves the overall responsiveness of the system substantially, but does make each iteration *slightly* slower, distributing the up-front cost across the batch.

Two main changes:

1. Create BatchSessions immediately, but do not create a whole graph execution state until the batch is executed.
BatchSessions are created with a `session_id` that does not exist in the sessions database.
The default state is changed to `"uninitialized"` to better represent this.

Results: Time to create 5000 batches reduced from over 30s to 2.5s

2. Use `executemany()` to retrieve lists of created sessions.
Results: time to create 5000 batches reduced from 2.5s to under 0.5s

Other changes:

- set BatchSession state to `"in_progress"` just before `invoke()` is called
- rename a few methods to accommodate the new behaviour
- remove unused `BatchProcessStorage.get_created_sessions()` method
2023-08-21 22:22:19 +10:00
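The `executemany()` pattern credited with the speedup above can be sketched in isolation: one prepared statement, many rows, a single commit. Table and column names here are illustrative, not the actual schema.

```python
# Minimal sketch of batch insertion with executemany(): the statement is
# prepared once and the transaction commits once, instead of per row.
import sqlite3
import uuid

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE batch_session (batch_id TEXT, session_id TEXT, state TEXT)"
)

rows = [("batch-1", uuid.uuid4().hex, "uninitialized") for _ in range(5000)]
with conn:  # commits once, when the block exits
    conn.executemany("INSERT INTO batch_session VALUES (?, ?, ?)", rows)

count = conn.execute("SELECT COUNT(*) FROM batch_session").fetchone()[0]
```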
50816432dc Merge branch 'main' into feat/batch-graphs 2023-08-21 19:51:41 +10:00
b98c9b516a feat: add batch docstrings 2023-08-21 19:51:16 +10:00
a15a5bc3b8 fix(api): correct get_batch response model 2023-08-21 19:51:02 +10:00
018ff56314 Merge branch 'main' into feat/batch-graphs 2023-08-18 23:32:08 -04:00
137fbacb92 Fix flake8 2023-08-18 15:47:27 -04:00
4b6d9a73ed Merge branch 'main' into feat/batch-graphs 2023-08-18 15:40:34 -04:00
3e26214b83 Add a few more endpoints for managing batches 2023-08-18 15:38:16 -04:00
0282f46c71 Add runs field for running the same batch multiple times 2023-08-18 13:41:07 -04:00
99e03fe92e Run unmodified graph if no batch data is provided 2023-08-18 13:33:09 -04:00
cb65526880 More session not found handling 2023-08-17 14:23:12 -04:00
59bc9ed399 fix(backend): handle BatchSessionNotFoundException in BatchManager._process()
The internal `BatchProcessStorage.get_session()` method throws when it finds nothing, but we were not catching any exceptions.

This caused an exception when the batch manager handled a `graph_execution_state_complete` event that did not originate from a batch.

Fixed by handling the exception.
2023-08-17 13:58:11 +10:00
e62d5478fd fix(backend): fix sqlite cannot commit - no transaction is active
The `commit()` was called even if we hadn't executed anything
2023-08-17 13:55:38 +10:00
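One way to avoid the "cannot commit - no transaction is active" class of error is to commit only when a write has actually opened a transaction. A sketch of that guard, using `sqlite3.Connection.in_transaction` (the helper name is hypothetical):

```python
# Sketch: only commit when a transaction is actually open.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (x INTEGER)")
conn.commit()


def maybe_commit(conn: sqlite3.Connection) -> bool:
    # in_transaction is True only after a statement has opened an
    # implicit transaction; committing otherwise is skipped entirely.
    if conn.in_transaction:
        conn.commit()
        return True
    return False


committed_without_write = maybe_commit(conn)  # nothing pending -> skipped
conn.execute("INSERT INTO t VALUES (1)")      # opens a transaction
committed_after_write = maybe_commit(conn)
```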
2cf0d61b3e Merge branch 'main' into feat/batch-graphs 2023-08-17 13:33:17 +10:00
cc3c2756bd feat(backend): rename batch changes variable
`updateSession` -> `changes`
2023-08-17 13:32:32 +10:00
67cf594bb3 feat(backend): add missing types to batch_manager_storage.py 2023-08-17 13:29:19 +10:00
c5b963f1a6 fix(backend): typo
`relavent` -> `relevant`
2023-08-17 12:47:58 +10:00
4d2dd6bb10 feat(backend): rename BatchManager.process to _process
Just to make it clear that this is not a method on the ABC.
2023-08-17 12:47:05 +10:00
7e4beab4ff feat(backend): surface BatchSessionNotFoundException
Catch this exception in the router and return an appropriate `HTTPException`.
2023-08-17 12:45:32 +10:00
e16b5f7cdc feat(backend): deserialize batch session directly
If the values from the `session_dict` are invalid, the model instantiation will fail, or if we end up with an invalid `batch_id`, the app will not run. So I think just parsing the dict directly is equivalent.

Also the LSP analyser is pleased now - no red squigglies.
2023-08-17 12:37:03 +10:00
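The change described above — building the model straight from the stored dict and letting construction fail on bad data — can be sketched with a stdlib dataclass standing in for the pydantic model:

```python
# Sketch: deserialize the stored dict directly into the model; invalid
# keys raise at construction time rather than slipping through.
from dataclasses import dataclass


@dataclass
class BatchSession:
    batch_id: str
    session_id: str
    state: str = "uninitialized"


session_dict = {"batch_id": "batch-1", "session_id": "sess-1", "state": "completed"}
session = BatchSession(**session_dict)  # unexpected keys would raise TypeError
```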
1f355d5810 feat(backend): update batch_manager_storage.py docstrings 2023-08-17 12:31:51 +10:00
df7370f9d9 chore(backend): remove unused code 2023-08-17 12:16:34 +10:00
5bec64d65b fix(backend): fix typings in batch_manager.py
- `batch_indicies` is `tuple[int]` not `list[int]`
- explicit `None` return values
2023-08-17 12:07:20 +10:00
8cf9bd47b2 chore(backend): remove unnecessary batch validation function
The `Batch` model is fully validated by pydantic on instantiation; we do not need any validation logic for it.
2023-08-17 11:59:47 +10:00
c91621b46c fix(backend): BatchProcess.batch_id is required
Providing a `default_factory` is enough for pydantic to know to create the attribute on instantiation if it's not already provided. We can then make the typing just `str`.
2023-08-17 11:58:29 +10:00
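The `default_factory` point above can be shown with a stdlib dataclass in place of the pydantic model: because the id is generated at instantiation, the attribute can be typed as plain `str` rather than `Optional[str]`.

```python
# Sketch: default_factory generates the id when no value is supplied,
# so the field is always a str after construction.
import uuid
from dataclasses import dataclass, field


def uuid_string() -> str:
    return uuid.uuid4().hex


@dataclass
class BatchProcess:
    batch_id: str = field(default_factory=uuid_string)


a = BatchProcess()                     # id generated automatically
b = BatchProcess(batch_id="fixed-id")  # explicit value still honoured
```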
f246b236dd fix(api): fix start_batch route responses 2023-08-17 11:51:14 +10:00
f7277a8b21 Run python black 2023-08-16 15:44:52 -04:00
796ee1246b Add a batch validation test 2023-08-16 15:42:45 -04:00
29fceb960d Fix batch_manager test 2023-08-16 15:33:15 -04:00
796ff34c8a Testing out Spencer's batch data structure 2023-08-16 15:21:11 -04:00
d6a5c2dbe3 Fix tests 2023-08-16 14:35:49 -04:00
ef8dc2e8c5 Merge branch 'main' into feat/batch-graphs 2023-08-16 14:03:34 -04:00
314891a125 Merge branch 'main' into feat/batch-graphs 2023-08-15 22:42:49 -04:00
2d3094f988 Run python black 2023-08-15 21:51:45 -04:00
abf09fc8fa Switch sqlite clients to only use one connection 2023-08-15 21:46:24 -04:00
15e7ca1baa Break apart create/start logic 2023-08-15 16:28:47 -04:00
6cb90e01de Graph is required in batch create 2023-08-15 16:13:51 -04:00
faa4574970 Turn off WAL mode 2023-08-15 15:59:42 -04:00
cc5755d5b1 Update invokeai/app/services/batch_manager_storage.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-08-15 15:54:57 -04:00
85105fc070 Update invokeai/app/services/batch_manager_storage.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-08-15 15:54:17 -04:00
ed40aee4c5 Merge branch 'main' into feat/batch-graphs 2023-08-15 15:48:40 -04:00
f8d8b16267 Run python black 2023-08-14 11:01:31 -04:00
846e52f2ea Add test for batch manager 2023-08-14 10:57:18 -04:00
69f541075c Merge branch 'main' into feat/batch-graphs 2023-08-14 10:32:35 -04:00
1debc31e3d Allow cancel of running batch 2023-08-11 15:52:49 -04:00
1d798d4119 Return session id's on batch creation 2023-08-11 11:45:27 -04:00
c1dde83abb Clean up erroneously added lines 2023-08-10 14:28:50 -04:00
280ac15da2 Go back to 1 lock per table 2023-08-10 14:26:22 -04:00
e751f7d815 More testing 2023-08-10 14:09:00 -04:00
e26e4740b3 Testing sqlite issues with batch_manager 2023-08-10 11:38:28 -04:00
835d76af45 Merge branch 'main' into feat/batch-graphs 2023-08-01 16:44:30 -04:00
a3e099bbc0 Instantiate batch managers 2023-08-01 16:44:17 -04:00
a61685696f Run black formatting 2023-08-01 16:41:40 -04:00
02aa93c67c Cancel batch endpoint 2023-07-31 16:05:27 -04:00
55b921818d Create batch manager 2023-07-31 15:45:35 -04:00
bb681a8a11 Merge branch 'main' into feat/batch-graphs 2023-07-31 13:22:11 -04:00
74e0fbce42 Merge branch 'main' into feat/batch-graphs 2023-07-25 22:25:55 -04:00
f080c56771 Testing out generating a new session for each batch_index 2023-07-25 16:50:07 -04:00
d2f968b902 Trying different places of applying batches 2023-07-25 10:23:17 -04:00
e81601acf3 add todo 2023-07-24 18:12:05 -04:00
7073dc0d5d Fix next call in graphexecutionstate.next 2023-07-24 17:45:05 -04:00
d090be60e8 Make batch_indices in graph class more clear 2023-07-24 17:43:49 -04:00
4bad96d9d6 WIP running graphs as batches 2023-07-24 17:41:54 -04:00
20 changed files with 1727 additions and 97 deletions

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from logging import Logger
import sqlite3
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
@@ -28,6 +29,8 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.batch_manager import BatchManager
from ..services.batch_manager_storage import SqliteBatchProcessStorage
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService
@@ -71,18 +74,18 @@ class ApiDependencies:
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
boards = BoardService(
services=BoardServiceDependencies(
@@ -116,15 +119,19 @@ class ApiDependencies:
)
)
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
batch_manager=batch_manager,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,

View File

@@ -0,0 +1,106 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import Body, HTTPException, Path, Response
from fastapi.routing import APIRouter
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...services.batch_manager import Batch, BatchProcessResponse
from ...services.graph import Graph
from ..dependencies import ApiDependencies
batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"])
@batches_router.post(
"/",
operation_id="create_batch",
responses={
200: {"model": BatchProcessResponse},
400: {"description": "Invalid json"},
},
)
async def create_batch(
graph: Graph = Body(description="The graph to initialize the session with"),
batch: Batch = Body(description="Batch config to apply to the given graph"),
) -> BatchProcessResponse:
"""Creates a batch process"""
return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
@batches_router.put(
"/b/{batch_process_id}/invoke",
operation_id="start_batch",
responses={
202: {"description": "Batch process started"},
404: {"description": "Batch session not found"},
},
)
async def start_batch(
batch_process_id: str = Path(description="ID of Batch to start"),
) -> Response:
"""Executes a batch process"""
try:
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
return Response(status_code=202)
except BatchSessionNotFoundException:
raise HTTPException(status_code=404, detail="Batch session not found")
@batches_router.delete(
"/b/{batch_process_id}",
operation_id="cancel_batch",
responses={202: {"description": "The batch is canceled"}},
)
async def cancel_batch(
batch_process_id: str = Path(description="The id of the batch process to cancel"),
) -> Response:
"""Cancels a batch process"""
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
return Response(status_code=202)
@batches_router.get(
"/incomplete",
operation_id="list_incomplete_batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_incomplete_batches() -> list[BatchProcessResponse]:
"""Lists incomplete batch processes"""
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
@batches_router.get(
"/",
operation_id="list_batches",
responses={200: {"model": list[BatchProcessResponse]}},
)
async def list_batches() -> list[BatchProcessResponse]:
"""Lists all batch processes"""
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
@batches_router.get(
"/b/{batch_process_id}",
operation_id="get_batch",
responses={200: {"model": BatchProcessResponse}},
)
async def get_batch(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> BatchProcessResponse:
"""Gets a Batch Process"""
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
@batches_router.get(
"/b/{batch_process_id}/sessions",
operation_id="get_batch_sessions",
responses={200: {"model": list[BatchSession]}},
)
async def get_batch_sessions(
batch_process_id: str = Path(description="The id of the batch process to get"),
) -> list[BatchSession]:
"""Gets a list of batch sessions for a given batch process"""
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)

View File

@@ -9,13 +9,7 @@ from pydantic.fields import Field
# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies

View File

@@ -13,11 +13,15 @@ class SocketIO:
def __init__(self, app: FastAPI):
self.__sio = SocketManager(app=app)
self.__sio.on("subscribe", handler=self._handle_sub)
self.__sio.on("unsubscribe", handler=self._handle_unsub)
self.__sio.on("subscribe_session", handler=self._handle_sub_session)
self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
self.__sio.on("subscribe_batch", handler=self._handle_sub_batch)
self.__sio.on("unsubscribe_batch", handler=self._handle_unsub_batch)
local_handler.register(event_name=EventServiceBase.batch_event, _func=self._handle_batch_event)
async def _handle_session_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
@@ -25,12 +29,25 @@ class SocketIO:
room=event[1]["data"]["graph_execution_state_id"],
)
async def _handle_sub(self, sid, data, *args, **kwargs):
async def _handle_sub_session(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.enter_room(sid, data["session"])
# @app.sio.on('unsubscribe')
async def _handle_unsub(self, sid, data, *args, **kwargs):
async def _handle_unsub_session(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.leave_room(sid, data["session"])
async def _handle_batch_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["batch_id"],
)
async def _handle_sub_batch(self, sid, data, *args, **kwargs):
if "batch_id" in data:
self.__sio.enter_room(sid, data["batch_id"])
async def _handle_unsub_batch(self, sid, data, *args, **kwargs):
if "batch_id" in data:
self.__sio.leave_room(sid, data["batch_id"])

View File

@@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir
import mimetypes
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.routers import sessions, batches, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
@@ -90,6 +90,8 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(batches.batches_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")

View File

@@ -5,6 +5,7 @@ import re
import shlex
import sys
import time
import sqlite3
from typing import Union, get_type_hints, Optional
from pydantic import BaseModel, ValidationError
@@ -29,6 +30,8 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.batch_manager import BatchManager
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
from invokeai.app.services.invocation_stats import InvocationStatsService
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -252,19 +255,18 @@ def invoke_cli():
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
boards = BoardService(
services=BoardServiceDependencies(
@@ -298,15 +300,19 @@ def invoke_cli():
)
)
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images,
boards=boards,
batch_manager=batch_manager,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),

View File

@@ -0,0 +1,215 @@
from abc import ABC, abstractmethod
from itertools import product
from typing import Optional
from uuid import uuid4
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from pydantic import BaseModel, Field
from invokeai.app.services.batch_manager_storage import (
Batch,
BatchProcess,
BatchProcessStorageBase,
BatchSession,
BatchSessionChanges,
BatchSessionNotFoundException,
)
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker
class BatchProcessResponse(BaseModel):
batch_id: str = Field(description="ID for the batch")
session_ids: list[str] = Field(description="List of session IDs created for this batch")
class BatchManagerBase(ABC):
@abstractmethod
def start(self, invoker: Invoker) -> None:
"""Starts the BatchManager service"""
pass
@abstractmethod
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
"""Creates a batch process"""
pass
@abstractmethod
def run_batch_process(self, batch_id: str) -> None:
"""Runs a batch process"""
pass
@abstractmethod
def cancel_batch_process(self, batch_process_id: str) -> None:
"""Cancels a batch process"""
pass
@abstractmethod
def get_batch(self, batch_id: str) -> BatchProcessResponse:
"""Gets a batch process"""
pass
@abstractmethod
def get_batch_processes(self) -> list[BatchProcessResponse]:
"""Gets all batch processes"""
pass
@abstractmethod
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
"""Gets all incomplete batch processes"""
pass
@abstractmethod
def get_sessions(self, batch_id: str) -> list[BatchSession]:
"""Gets the sessions associated with a batch"""
pass
class BatchManager(BatchManagerBase):
"""Responsible for managing currently running and scheduled batch jobs"""
__invoker: Invoker
__batch_process_storage: BatchProcessStorageBase
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
super().__init__()
self.__batch_process_storage = batch_process_storage
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
async def on_event(self, event: Event):
event_name = event[1]["event"]
match event_name:
case "graph_execution_state_complete":
await self._process(event, False)
case "invocation_error":
await self._process(event, True)
return event
async def _process(self, event: Event, err: bool) -> None:
data = event[1]["data"]
try:
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
except BatchSessionNotFoundException:
return None
changes = BatchSessionChanges(state="error" if err else "completed")
batch_session = self.__batch_process_storage.update_session_state(
batch_session.batch_id,
batch_session.session_id,
changes,
)
sessions = self.get_sessions(batch_session.batch_id)
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
if not batch_process.canceled:
self.run_batch_process(batch_process.batch_id)
def _create_graph_execution_state(
self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
) -> GraphExecutionState:
graph = batch_process.graph.copy(deep=True)
batch = batch_process.batch
for index, bdl in enumerate(batch.data):
for bd in bdl:
node = graph.get_node(bd.node_path)
if node is None:
continue
batch_index = batch_indices[index]
datum = bd.items[batch_index]
key = bd.field_name
node.__dict__[key] = datum
graph.update_node(bd.node_path, node)
return GraphExecutionState(graph=graph)
def run_batch_process(self, batch_id: str) -> None:
self.__batch_process_storage.start(batch_id)
batch_process = self.__batch_process_storage.get(batch_id)
next_batch_index = self._get_batch_index_tuple(batch_process)
if next_batch_index is None:
# finished with current run
if batch_process.current_run >= (batch_process.batch.runs - 1):
# finished with all runs
return
batch_process.current_batch_index = 0
batch_process.current_run += 1
next_batch_index = self._get_batch_index_tuple(batch_process)
if next_batch_index is None:
# shouldn't happen; satisfy types
return
# remember to increment the batch index
batch_process.current_batch_index += 1
self.__batch_process_storage.save(batch_process)
ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
next_session = self.__batch_process_storage.create_session(
BatchSession(
batch_id=batch_id,
session_id=str(uuid4()),
state="uninitialized",
batch_index=batch_process.current_batch_index,
)
)
ges.id = next_session.session_id
self.__invoker.services.graph_execution_manager.set(ges)
self.__batch_process_storage.update_session_state(
batch_id=next_session.batch_id,
session_id=next_session.session_id,
changes=BatchSessionChanges(state="in_progress"),
)
self.__invoker.services.events.emit_batch_session_created(next_session.batch_id, next_session.session_id)
self.__invoker.invoke(ges, invoke_all=True)
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
batch_process = BatchProcess(
batch=batch,
graph=graph,
)
batch_process = self.__batch_process_storage.save(batch_process)
return BatchProcessResponse(
batch_id=batch_process.batch_id,
session_ids=[],
)
def get_sessions(self, batch_id: str) -> list[BatchSession]:
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
def get_batch(self, batch_id: str) -> BatchProcess:
return self.__batch_process_storage.get(batch_id)
def get_batch_processes(self) -> list[BatchProcessResponse]:
bps = self.__batch_process_storage.get_all()
return self._get_batch_process_responses(bps)
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
bps = self.__batch_process_storage.get_incomplete()
return self._get_batch_process_responses(bps)
def cancel_batch_process(self, batch_process_id: str) -> None:
self.__batch_process_storage.cancel(batch_process_id)
def _get_batch_process_responses(self, batch_processes: list[BatchProcess]) -> list[BatchProcessResponse]:
sessions = list()
res: list[BatchProcessResponse] = list()
for bp in batch_processes:
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
res.append(
BatchProcessResponse(
batch_id=bp.batch_id,
session_ids=[session.session_id for session in sessions],
)
)
return res
def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
batch_indices = list()
for batchdata in batch_process.batch.data:
batch_indices.append(list(range(len(batchdata[0].items))))
try:
return list(product(*batch_indices))[batch_process.current_batch_index]
except IndexError:
return None
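The index walk in `_get_batch_index_tuple` can be shown standalone: each zipped batch-data list contributes one axis, and the batch steps through the Cartesian product of those axes, one index tuple per session. Axis lengths below are illustrative.

```python
# Sketch of the Cartesian-product index walk used by the batch manager.
from itertools import product

# e.g. two axes: 3 zipped prompts and 2 seeds
axis_lengths = [3, 2]
indices = [list(range(n)) for n in axis_lengths]
all_tuples = list(product(*indices))  # every (prompt_index, seed_index) pair


def next_index_tuple(current_batch_index: int):
    try:
        return all_tuples[current_batch_index]
    except IndexError:
        return None  # the current run is finished


first = next_index_tuple(0)
last = next_index_tuple(len(all_tuples) - 1)
done = next_index_tuple(len(all_tuples))
```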

View File

@@ -0,0 +1,707 @@
import sqlite3
import threading
import uuid
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Union, cast
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.graph import Graph
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
class BatchData(BaseModel):
"""
A batch data collection.
"""
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
items: list[BatchDataType] = Field(
default_factory=list, description="The list of items to substitute into the node/field."
)
class Batch(BaseModel):
"""
A batch, consisting of a list of a list of batch data collections.
First, each inner list[BatchData] is zipped into a single batch data collection.
Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
"""
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")
@validator("runs")
def validate_positive_runs(cls, r: int):
if r < 1:
raise ValueError("runs must be a positive integer")
return r
@validator("data")
def validate_len(cls, v: list[list[BatchData]]):
for batch_data in v:
if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
raise ValueError("Zipped batch items must all have the same length")
return v
@validator("data")
def validate_types(cls, v: list[list[BatchData]]):
for batch_data in v:
for datum in batch_data:
for item in datum.items:
if not all(isinstance(item, type(i)) for i in datum.items):
raise TypeError("All items in a batch must have the same type")
return v
@validator("data")
def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
paths: set[tuple[str, str]] = set()
count: int = 0
for batch_data in v:
for datum in batch_data:
paths.add((datum.node_path, datum.field_name))
count += 1
if len(paths) != count:
raise ValueError("Each batch data must have a unique node_path and field_name")
return v
def uuid_string():
res = uuid.uuid4()
return str(res)
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
class BatchSession(BaseModel):
batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
session_id: str = Field(
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
)
batch_index: int = Field(description="The index of this batch session in its parent batch process")
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
class BatchProcess(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
batch: Batch = Field(description="The Batch to apply to this session.")
current_batch_index: int = Field(default=0, description="The last executed batch index")
current_run: int = Field(default=0, description="The current run of the batch")
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
class BatchProcessNotFoundException(Exception):
"""Raised when a Batch Process record is not found."""
def __init__(self, message="BatchProcess record not found"):
super().__init__(message)
class BatchProcessSaveException(Exception):
"""Raised when a Batch Process record cannot be saved."""
def __init__(self, message="BatchProcess record not saved"):
super().__init__(message)
class BatchProcessDeleteException(Exception):
"""Raised when a Batch Process record cannot be deleted."""
def __init__(self, message="BatchProcess record not deleted"):
super().__init__(message)
class BatchSessionNotFoundException(Exception):
"""Raised when a Batch Session record is not found."""
def __init__(self, message="BatchSession record not found"):
super().__init__(message)
class BatchSessionSaveException(Exception):
"""Raised when a Batch Session record cannot be saved."""
def __init__(self, message="BatchSession record not saved"):
super().__init__(message)
class BatchSessionDeleteException(Exception):
"""Raised when a Batch Session record cannot be deleted."""
def __init__(self, message="BatchSession record not deleted"):
super().__init__(message)
class BatchProcessStorageBase(ABC):
"""Low-level service responsible for interfacing with the Batch Process record store."""
@abstractmethod
def delete(self, batch_id: str) -> None:
"""Deletes a BatchProcess record."""
pass
@abstractmethod
def save(
self,
batch_process: BatchProcess,
) -> BatchProcess:
"""Saves a BatchProcess record."""
pass
@abstractmethod
def get(
self,
batch_id: str,
) -> BatchProcess:
"""Gets a BatchProcess record."""
pass
@abstractmethod
def get_all(
self,
) -> list[BatchProcess]:
"""Gets all BatchProcess records."""
pass
@abstractmethod
def get_incomplete(
self,
) -> list[BatchProcess]:
"""Gets all BatchProcess records that have incomplete sessions."""
pass
@abstractmethod
def start(
self,
batch_id: str,
) -> None:
"""'Starts' a BatchProcess record by setting its `canceled` attribute to False."""
pass
@abstractmethod
def cancel(
self,
batch_id: str,
) -> None:
"""'Cancels' a BatchProcess record by setting its `canceled` attribute to True."""
pass
@abstractmethod
def create_session(
self,
session: BatchSession,
) -> BatchSession:
"""Creates a BatchSession attached to a BatchProcess."""
pass
@abstractmethod
def create_sessions(
self,
sessions: list[BatchSession],
) -> list[BatchSession]:
"""Creates many BatchSessions attached to a BatchProcess."""
pass
@abstractmethod
def get_session_by_session_id(self, session_id: str) -> BatchSession:
"""Gets a BatchSession by session_id"""
pass
@abstractmethod
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
"""Gets all BatchSessions for a given BatchProcess id."""
pass
@abstractmethod
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
"""Gets all BatchSessions for a given list of session ids."""
pass
@abstractmethod
def get_next_session(self, batch_id: str) -> BatchSession:
"""Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
pass
@abstractmethod
def update_session_state(
self,
batch_id: str,
session_id: str,
changes: BatchSessionChanges,
) -> BatchSession:
"""Updates the state of a BatchSession record."""
pass
class SqliteBatchProcessStorage(BatchProcessStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `batch_process` table and `batch_session` junction table."""
# Create the `batch_process` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS batch_process (
batch_id TEXT NOT NULL PRIMARY KEY,
batch TEXT NOT NULL,
graph TEXT NOT NULL,
current_batch_index NUMBER NOT NULL,
current_run NUMBER NOT NULL,
canceled BOOLEAN NOT NULL DEFAULT(0),
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
AFTER UPDATE
ON batch_process FOR EACH ROW
BEGIN
UPDATE batch_process SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE batch_id = old.batch_id;
END;
"""
)
# Create the `batch_session` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS batch_session (
batch_id TEXT NOT NULL,
session_id TEXT NOT NULL,
state TEXT NOT NULL,
batch_index NUMBER NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between batch_process and batch_session using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (batch_id,session_id),
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
);
"""
)
# Add index for batch id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
"""
)
# Add index for batch id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
AFTER UPDATE
ON batch_session FOR EACH ROW
BEGIN
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE batch_id = old.batch_id AND session_id = old.session_id;
END;
"""
)
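Both tables above keep `updated_at` fresh with an AFTER UPDATE trigger rather than application code. A minimal in-memory sketch of that pattern, using an illustrative `demo` table rather than the real schema:

```python
import sqlite3

# Minimal sketch of the updated_at trigger pattern used by the tables above,
# against an illustrative `demo` table (not the real batch_process schema).
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE demo (
        id TEXT PRIMARY KEY,
        value TEXT,
        updated_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
    );
    CREATE TRIGGER tg_demo_updated_at
    AFTER UPDATE ON demo FOR EACH ROW
    BEGIN
        UPDATE demo SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
        WHERE id = old.id;
    END;
    """
)
conn.execute("INSERT INTO demo (id, value) VALUES ('a', 'first')")
before = conn.execute("SELECT updated_at FROM demo WHERE id = 'a'").fetchone()[0]
conn.execute("UPDATE demo SET value = 'second' WHERE id = 'a'")
after = conn.execute("SELECT updated_at FROM demo WHERE id = 'a'").fetchone()[0]
print(after >= before)  # the trigger refreshed the timestamp on UPDATE
```

Because the timestamp format is lexicographically sortable, the trigger's effect can be checked with a plain string comparison.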
def delete(self, batch_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM batch_process
WHERE batch_id = ?;
""",
(batch_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessDeleteException from e
except Exception as e:
self._conn.rollback()
raise BatchProcessDeleteException from e
finally:
self._lock.release()
def save(
self,
batch_process: BatchProcess,
) -> BatchProcess:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
VALUES (?, ?, ?, ?, ?);
""",
(
batch_process.batch_id,
batch_process.batch.json(),
batch_process.graph.json(),
batch_process.current_batch_index,
batch_process.current_run,
),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessSaveException from e
finally:
self._lock.release()
return self.get(batch_process.batch_id)
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
"""Deserializes a batch session."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
batch_id = session_dict.get("batch_id", "unknown")
batch_raw = session_dict.get("batch", "unknown")
graph_raw = session_dict.get("graph", "unknown")
current_batch_index = session_dict.get("current_batch_index", 0)
current_run = session_dict.get("current_run", 0)
canceled = session_dict.get("canceled", 0)
return BatchProcess(
batch_id=batch_id,
batch=parse_raw_as(Batch, batch_raw),
graph=parse_raw_as(Graph, graph_raw),
current_batch_index=current_batch_index,
current_run=current_run,
canceled=canceled == 1,
)
def get(
self,
batch_id: str,
) -> BatchProcess:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_process
WHERE batch_id = ?;
""",
(batch_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result))
def get_all(
self,
) -> list[BatchProcess]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_process
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
return list()
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
def get_incomplete(
self,
) -> list[BatchProcess]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT bp.*
FROM batch_process bp
WHERE bp.batch_id IN
(
SELECT batch_id
FROM batch_session bs
WHERE state IN ('uninitialized', 'in_progress')
);
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
return list()
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
def start(
self,
batch_id: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE batch_process
SET canceled = 0
WHERE batch_id = ?;
""",
(batch_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessSaveException from e
finally:
self._lock.release()
def cancel(
self,
batch_id: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE batch_process
SET canceled = 1
WHERE batch_id = ?;
""",
(batch_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessSaveException from e
finally:
self._lock.release()
def create_session(
self,
session: BatchSession,
) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
VALUES (?, ?, ?, ?);
""",
(session.batch_id, session.session_id, session.state, session.batch_index),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session_by_session_id(session.session_id)
def create_sessions(
self,
sessions: list[BatchSession],
) -> list[BatchSession]:
try:
self._lock.acquire()
session_data = [(session.batch_id, session.session_id, session.state, session.batch_index) for session in sessions]
self._cursor.executemany(
"""--sql
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
VALUES (?, ?, ?, ?);
""",
session_data,
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_sessions_by_session_ids([session.session_id for session in sessions])
def get_session_by_session_id(self, session_id: str) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE session_id = ?;
""",
(session_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
return self._deserialize_batch_session(dict(result))
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
"""Deserializes a batch session."""
return BatchSession.parse_obj(session_dict)
def get_next_session(self, batch_id: str) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = 'uninitialized'
ORDER BY batch_index
LIMIT 1;
""",
(batch_id,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
session = self._deserialize_batch_session(dict(result))
return session
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ?;
""",
(batch_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
try:
self._lock.acquire()
placeholders = ",".join("?" * len(session_ids))
self._cursor.execute(
f"""--sql
SELECT * FROM batch_session
WHERE session_id
IN ({placeholders})
""",
tuple(session_ids),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def update_session_state(
self,
batch_id: str,
session_id: str,
changes: BatchSessionChanges,
) -> BatchSession:
try:
self._lock.acquire()
# Change the state of a batch session
if changes.state is not None:
self._cursor.execute(
"""--sql
UPDATE batch_session
SET state = ?
WHERE batch_id = ? AND session_id = ?;
""",
(changes.state, batch_id, session_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session_by_session_id(session_id)
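`get_sessions_by_session_ids` above builds one `?` placeholder per id so the whole list can be bound safely in an `IN (...)` clause instead of interpolating values into the SQL string. A self-contained sketch of that technique against an illustrative table:

```python
import sqlite3

# Sketch of the dynamic-placeholder technique from get_sessions_by_session_ids:
# build one "?" per id so sqlite3 can bind the whole list as parameters.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE batch_session_demo (session_id TEXT PRIMARY KEY)")  # illustrative table
conn.executemany(
    "INSERT INTO batch_session_demo (session_id) VALUES (?)",
    [("s1",), ("s2",), ("s3",)],
)
session_ids = ["s1", "s3"]
placeholders = ",".join("?" * len(session_ids))
rows = conn.execute(
    f"SELECT session_id FROM batch_session_demo WHERE session_id IN ({placeholders})",
    tuple(session_ids),
).fetchall()
print(sorted(r[0] for r in rows))  # ['s1', 's3']
```

Only the placeholder string is built with f-string formatting; the values themselves always go through parameter binding.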

View File

@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC):
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()

View File

@ -89,15 +89,13 @@ class BoardRecordStorageBase(ABC):
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()

View File

@ -13,6 +13,7 @@ from invokeai.app.services.model_manager_service import (
class EventServiceBase:
session_event: str = "session_event"
batch_event: str = "batch_event"
"""Basic event bus, to have an empty stand-in when not needed"""
@ -20,12 +21,21 @@ class EventServiceBase:
pass
def __emit_session_event(self, event_name: str, payload: dict) -> None:
"""Session events are emitted to a room with the session_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.session_event,
payload=dict(event=event_name, data=payload),
)
def __emit_batch_event(self, event_name: str, payload: dict) -> None:
"""Batch events are emitted to a room with the batch_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.batch_event,
payload=dict(event=event_name, data=payload),
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
@ -187,3 +197,14 @@ class EventServiceBase:
error=error,
),
)
def emit_batch_session_created(
self,
batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a batch session is created"""
self.__emit_batch_event(
event_name="batch_session_created",
payload=dict(batch_id=batch_id, graph_execution_state_id=graph_execution_state_id),
)
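The emitters above stamp a timestamp onto the payload and wrap it in an `{event, data}` envelope before dispatch. A minimal stand-in sketch of that envelope (the helper name and fixed timestamp are illustrative, not the real `dispatch`/`get_timestamp` wiring):

```python
# Minimal sketch of the payload shape produced by the emitters above.
# The fixed timestamp stands in for get_timestamp(); dispatch/socket.io wiring is omitted.
def make_batch_event(event_name: str, payload: dict) -> dict:
    payload = dict(payload, timestamp=1234567890)  # stand-in for get_timestamp()
    return dict(event=event_name, data=payload)

evt = make_batch_event(
    "batch_session_created",
    {"batch_id": "b1", "graph_execution_state_id": "g1"},
)
print(evt["event"], evt["data"]["batch_id"])
```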

View File

@ -152,15 +152,13 @@ class ImageRecordStorageBase(ABC):
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
def __init__(self, conn: sqlite3.Connection) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.batch_manager import BatchManagerBase
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
@ -22,6 +23,7 @@ class InvocationServices:
"""Services that can be used by invocations"""
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
batch_manager: "BatchManagerBase"
board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC"
configuration: "InvokeAIAppConfig"
@ -38,6 +40,7 @@ class InvocationServices:
def __init__(
self,
batch_manager: "BatchManagerBase",
board_images: "BoardImagesServiceABC",
boards: "BoardServiceABC",
configuration: "InvokeAIAppConfig",
@ -52,6 +55,7 @@ class InvocationServices:
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
):
self.batch_manager = batch_manager
self.board_images = board_images
self.boards = boards

View File

@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
super().__init__()
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._conn = conn
self._cursor = self._conn.cursor()
self._create_table()
@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
parsed = parse_raw_as(item_type, item)
return parsed
return parse_raw_as(item_type, item)
def set(self, item: T):
try:

File diff suppressed because one or more lines are too long

View File

@ -102,6 +102,7 @@ dependencies = [
"flake8",
"Flake8-pyproject",
"pytest>6.0.0",
"pytest-asyncio",
"pytest-cov",
"pytest-datadir",
]
@ -176,6 +177,7 @@ version = { attr = "invokeai.version.__version__" }
#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
addopts = "--cov-report term --cov-report html --cov-report xml"
asyncio_mode = "auto"
[tool.coverage.run]
branch = true
source = ["invokeai"]

View File

@ -25,6 +25,7 @@ from invokeai.app.services.graph import (
LibraryGraph,
)
import pytest
import sqlite3
@pytest.fixture
@ -42,9 +43,8 @@ def simple_graph():
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions"
)
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
return InvocationServices(
model_manager=None, # type: ignore
events=TestEventService(),
@ -52,9 +52,10 @@ def mock_services() -> InvocationServices:
images=None, # type: ignore
latents=None, # type: ignore
boards=None, # type: ignore
batch_manager=None, # type: ignore
board_images=None, # type: ignore
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
performance_statistics=InvocationStatsService(graph_execution_manager),
processor=DefaultInvocationProcessor(),

View File

@ -6,18 +6,34 @@ from .test_nodes import (
create_edge,
wait_until,
)
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.api.events import FastAPIEventService
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
from invokeai.app.services.batch_manager import (
Batch,
BatchManager,
)
from invokeai.app.services.graph import (
Graph,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
)
import pytest
import pytest_asyncio
import sqlite3
import time
from httpx import AsyncClient
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
@pytest.fixture
@ -29,25 +45,128 @@ def simple_graph():
return g
@pytest.fixture
def simple_batch():
return Batch(
data=[
[
BatchData(
node_path="1",
field_name="prompt",
items=[
"Tomato sushi",
"Strawberry sushi",
"Broccoli sushi",
"Asparagus sushi",
"Tea sushi",
],
)
],
[
BatchData(
node_path="2",
field_name="prompt",
items=[
"Ume sushi",
"Ichigo sushi",
"Momo sushi",
"Mikan sushi",
"Cha sushi",
],
)
],
]
)
@pytest.fixture
def graph_with_subgraph():
sub_g = Graph()
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
sub_g.add_node(TextToImageTestInvocation(id="2"))
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
g = Graph()
g.add_node(GraphInvocation(id="1", graph=sub_g))
return g
@pytest.fixture
def batch_with_subgraph():
return Batch(
data=[
[
BatchData(
node_path="1.1",
field_name="prompt",
items=[
"Tomato sushi",
"Strawberry sushi",
"Broccoli sushi",
"Asparagus sushi",
"Tea sushi",
],
)
],
[
BatchData(
node_path="1.2",
field_name="prompt",
items=[
"Ume sushi",
"Ichigo sushi",
"Momo sushi",
"Mikan sushi",
"Cha sushi",
],
)
],
]
)
# @pytest_asyncio.fixture(scope="module")
# def event_loop():
# import asyncio
# try:
# loop = asyncio.get_running_loop()
# except RuntimeError as e:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# # fastapi_events.event_store
# yield loop
# loop.close()
@pytest.fixture(scope="session")
def db_conn():
return sqlite3.connect(sqlite_memory, check_same_thread=False)
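Sharing one session-scoped connection matters here because every `sqlite3.connect(":memory:")` call opens its own private database; tables created on one connection are invisible to another. A minimal sketch:

```python
import sqlite3

# Each ":memory:" connection gets its own private database, so a table created
# on one connection is invisible to another -- hence the shared db_conn fixture.
conn_a = sqlite3.connect(":memory:")
conn_b = sqlite3.connect(":memory:")
conn_a.execute("CREATE TABLE t (id TEXT)")
a_tables = conn_a.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
b_tables = conn_b.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
print(len(a_tables), len(b_tables))  # 1 0
```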
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=sqlite_memory, table_name="graph_executions"
@pytest.fixture(autouse=True)
async def mock_services(db_conn: sqlite3.Connection) -> InvocationServices:
app = FastAPI()
event_handler_id: int = id(app)
app.add_middleware(
EventHandlerASGIMiddleware,
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
middleware_id=event_handler_id,
)
client = AsyncClient(app=app)
# NOTE: none of these are actually called by the test invocations
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
events = FastAPIEventService(event_handler_id)
return InvocationServices(
model_manager=None, # type: ignore
events=TestEventService(),
events=events,
logger=None, # type: ignore
images=None, # type: ignore
latents=None, # type: ignore
batch_manager=BatchManager(batch_manager_storage),
boards=None, # type: ignore
board_images=None, # type: ignore
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
@ -130,3 +249,135 @@ def test_handles_errors(mock_invoker: Invoker):
assert g.is_complete()
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batch=batch_with_subgraph,
graph=graph_with_subgraph,
)
assert batch_process_res.batch_id
# TODO: without the mock events service emitting the `graph_execution_state` events,
# the batch sessions do not know when they have finished, so this logic will fail
# assert len(batch_process_res.session_ids) == 25
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
# def has_executed_all_batches(batch_id: str):
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
# print(batch_sessions)
# return all((s.state == "completed" for s in batch_sessions))
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batch=batch_with_subgraph,
graph=graph_with_subgraph,
)
mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
import asyncio
sessions = []
attempts = 0
while len(sessions) != 25 and attempts < 20:
batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id)
sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id)
await asyncio.sleep(1)
attempts += 1
assert len(sessions) == 25
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batch=simple_batch,
graph=simple_graph,
)
assert batch_process_res.batch_id
# TODO: without the mock events service emitting the `graph_execution_state` events,
# the batch sessions do not know when they have finished, so this logic will fail
# assert len(batch_process_res.session_ids) == 25
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
# def has_executed_all_batches(batch_id: str):
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
# print(batch_sessions)
# return all((s.state == "completed" for s in batch_sessions))
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
def test_cannot_create_bad_batches():
from pydantic import ValidationError
with pytest.raises(ValidationError):
Batch(  # This batch has a duplicate node_path|field_name combo
data=[
[
BatchData(
node_path="1",
field_name="prompt",
items=["Tomato sushi"],
)
],
[
BatchData(
node_path="1",
field_name="prompt",
items=["Ume sushi"],
)
],
]
)
with pytest.raises(ValidationError):
Batch(  # This batch has different item list lengths in the same group
data=[
[
BatchData(
node_path="1",
field_name="prompt",
items=["Tomato sushi"],
),
BatchData(
node_path="1",
field_name="prompt",
items=["Tomato sushi", "Courgette sushi"],
),
],
[
BatchData(
node_path="1",
field_name="prompt",
items=["Ume sushi"],
)
],
]
)
with pytest.raises(ValidationError):
Batch(  # This batch has a type mismatch in a single items list
data=[
[
BatchData(
node_path="1",
field_name="prompt",
items=["Tomato sushi", 5],
),
],
]
)

View File

@ -51,6 +51,7 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
@invocation("test_text_to_image")
class TextToImageTestInvocation(BaseInvocation):
prompt: str = Field(default="")
prompt2: str = Field(default="")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))

View File

@ -1,20 +1,27 @@
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from pydantic import BaseModel, Field
import pytest
import sqlite3
class TestModel(BaseModel):
id: str = Field(description="ID")
name: str = Field(description="Name")
def test_sqlite_service_can_create_and_get():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
@pytest.fixture
def db() -> SqliteItemStorage[TestModel]:
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SqliteItemStorage[TestModel](db_conn, "test", "id")
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
assert db.get("1") == TestModel(id="1", name="Test")
def test_sqlite_service_can_list():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
@ -30,15 +37,13 @@ def test_sqlite_service_can_list():
]
def test_sqlite_service_can_delete():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.delete("1")
assert db.get("1") is None
def test_sqlite_service_calls_set_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
called = False
def on_changed(item: TestModel):
@ -50,8 +55,7 @@ def test_sqlite_service_calls_set_callback():
assert called
def test_sqlite_service_calls_delete_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
called = False
def on_deleted(item_id: str):
@ -64,8 +68,7 @@ def test_sqlite_service_calls_delete_callback():
assert called
def test_sqlite_service_can_list_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
@ -77,8 +80,7 @@ def test_sqlite_service_can_list_with_pagination():
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
def test_sqlite_service_can_list_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
@ -90,8 +92,7 @@ def test_sqlite_service_can_list_with_pagination_and_offset():
assert results.items == [TestModel(id="3", name="Test")]
def test_sqlite_service_can_search():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
@ -107,8 +108,7 @@ def test_sqlite_service_can_search():
]
def test_sqlite_service_can_search_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))
@ -120,8 +120,7 @@ def test_sqlite_service_can_search_with_pagination():
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
def test_sqlite_service_can_search_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test"))
db.set(TestModel(id="2", name="Test"))
db.set(TestModel(id="3", name="Test"))