mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Compare commits: 87 commits
lstein/fea... -> feat/batch
Commits (SHA1 only; author, date, and message columns were not captured in this mirror):

732780c376
ed7deee8f1
d22c4734ee
3e7dadd7b3
b777dba430
531c3bb1e2
331743ca0c
13429e66b3
2185c85287
e8a4a654ac
26f9ac9f21
8d78af5db7
babd26feab
e9b26e5e7d
6b946f53c4
70479b9827
35099dcdd8
670600a863
6d5403e19d
0f7695a081
d567d9f804
68f6140685
be971617e3
1652143671
88ae19a768
50816432dc
b98c9b516a
a15a5bc3b8
018ff56314
137fbacb92
4b6d9a73ed
3e26214b83
0282f46c71
99e03fe92e
cb65526880
59bc9ed399
e62d5478fd
2cf0d61b3e
cc3c2756bd
67cf594bb3
c5b963f1a6
4d2dd6bb10
7e4beab4ff
e16b5f7cdc
1f355d5810
df7370f9d9
5bec64d65b
8cf9bd47b2
c91621b46c
f246b236dd
f7277a8b21
796ee1246b
29fceb960d
796ff34c8a
d6a5c2dbe3
ef8dc2e8c5
314891a125
2d3094f988
abf09fc8fa
15e7ca1baa
6cb90e01de
faa4574970
cc5755d5b1
85105fc070
ed40aee4c5
f8d8b16267
846e52f2ea
69f541075c
1debc31e3d
1d798d4119
c1dde83abb
280ac15da2
e751f7d815
e26e4740b3
835d76af45
a3e099bbc0
a61685696f
02aa93c67c
55b921818d
bb681a8a11
74e0fbce42
f080c56771
d2f968b902
e81601acf3
7073dc0d5d
d090be60e8
4bad96d9d6
@@ -1,6 +1,7 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

 from logging import Logger
+import sqlite3
 from invokeai.app.services.board_image_record_storage import (
     SqliteBoardImageRecordStorage,
 )

@@ -28,6 +29,8 @@ from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
 from ..services.model_manager_service import ModelManagerService
+from ..services.batch_manager import BatchManager
+from ..services.batch_manager_storage import SqliteBatchProcessStorage
 from ..services.invocation_stats import InvocationStatsService
 from .events import FastAPIEventService

@@ -71,18 +74,18 @@ class ApiDependencies:
         db_path.parent.mkdir(parents=True, exist_ok=True)
         db_location = str(db_path)

-        graph_execution_manager = SqliteItemStorage[GraphExecutionState](
-            filename=db_location, table_name="graph_executions"
-        )
+        db_conn = sqlite3.connect(db_location, check_same_thread=False)  # TODO: figure out a better threading solution
+
+        graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")

         urls = LocalUrlService()
-        image_record_storage = SqliteImageRecordStorage(db_location)
+        image_record_storage = SqliteImageRecordStorage(conn=db_conn)
         image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
         names = SimpleNameService()
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))

-        board_record_storage = SqliteBoardRecordStorage(db_location)
-        board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
+        board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
+        board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)

         boards = BoardService(
             services=BoardServiceDependencies(

@@ -116,15 +119,19 @@ class ApiDependencies:
             )
         )

+        batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
+        batch_manager = BatchManager(batch_manager_storage)
+
         services = InvocationServices(
             model_manager=ModelManagerService(config, logger),
             events=events,
             latents=latents,
             images=images,
+            batch_manager=batch_manager,
             boards=boards,
             board_images=board_images,
             queue=MemoryInvocationQueue(),
-            graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
+            graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
             graph_execution_manager=graph_execution_manager,
             processor=DefaultInvocationProcessor(),
             configuration=config,
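The pattern above, one process-wide connection created with check_same_thread=False and handed to every service, replaces the previous one-connection-per-service design. A minimal standalone sketch of the idea (the class and table names here are illustrative, not the InvokeAI API):

import sqlite3
import threading

# One shared connection for the whole process; sqlite3 connections are not
# thread-safe by default, so each consumer serializes access with a lock.
conn = sqlite3.connect("example.db", check_same_thread=False)

class RecordStore:
    def __init__(self, conn: sqlite3.Connection) -> None:
        self._conn = conn
        self._lock = threading.Lock()
        with self._lock:
            self._conn.execute("CREATE TABLE IF NOT EXISTS records (id TEXT PRIMARY KEY);")
            self._conn.commit()

    def add(self, record_id: str) -> None:
        with self._lock:
            self._conn.execute("INSERT OR IGNORE INTO records (id) VALUES (?);", (record_id,))
            self._conn.commit()

store_a = RecordStore(conn)
store_b = RecordStore(conn)  # same connection, same underlying database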
|
106
invokeai/app/api/routers/batches.py
Normal file
106
invokeai/app/api/routers/batches.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Response
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...services.batch_manager import Batch, BatchProcessResponse
|
||||
from ...services.graph import Graph
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"])
|
||||
|
||||
|
||||
@batches_router.post(
|
||||
"/",
|
||||
operation_id="create_batch",
|
||||
responses={
|
||||
200: {"model": BatchProcessResponse},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
)
|
||||
async def create_batch(
|
||||
graph: Graph = Body(description="The graph to initialize the session with"),
|
||||
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcessResponse:
|
||||
"""Creates a batch process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
||||
|
||||
|
||||
@batches_router.put(
|
||||
"/b/{batch_process_id}/invoke",
|
||||
operation_id="start_batch",
|
||||
responses={
|
||||
202: {"description": "Batch process started"},
|
||||
404: {"description": "Batch session not found"},
|
||||
},
|
||||
)
|
||||
async def start_batch(
|
||||
batch_process_id: str = Path(description="ID of Batch to start"),
|
||||
) -> Response:
|
||||
"""Executes a batch process"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
|
||||
return Response(status_code=202)
|
||||
except BatchSessionNotFoundException:
|
||||
raise HTTPException(status_code=404, detail="Batch session not found")
|
||||
|
||||
|
||||
@batches_router.delete(
|
||||
"/b/{batch_process_id}",
|
||||
operation_id="cancel_batch",
|
||||
responses={202: {"description": "The batch is canceled"}},
|
||||
)
|
||||
async def cancel_batch(
|
||||
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
||||
) -> Response:
|
||||
"""Cancels a batch process"""
|
||||
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/incomplete",
|
||||
operation_id="list_incomplete_batches",
|
||||
responses={200: {"model": list[BatchProcessResponse]}},
|
||||
)
|
||||
async def list_incomplete_batches() -> list[BatchProcessResponse]:
|
||||
"""Lists incomplete batch processes"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/",
|
||||
operation_id="list_batches",
|
||||
responses={200: {"model": list[BatchProcessResponse]}},
|
||||
)
|
||||
async def list_batches() -> list[BatchProcessResponse]:
|
||||
"""Lists all batch processes"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/b/{batch_process_id}",
|
||||
operation_id="get_batch",
|
||||
responses={200: {"model": BatchProcessResponse}},
|
||||
)
|
||||
async def get_batch(
|
||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||
) -> BatchProcessResponse:
|
||||
"""Gets a Batch Process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/b/{batch_process_id}/sessions",
|
||||
operation_id="get_batch_sessions",
|
||||
responses={200: {"model": list[BatchSession]}},
|
||||
)
|
||||
async def get_batch_sessions(
|
||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||
) -> list[BatchSession]:
|
||||
"""Gets a list of batch sessions for a given batch process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
|
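Taken together, the router above supports a create -> invoke -> inspect flow. A hedged sketch of a client session (httpx as the client, the host/port, and the stub graph payload are all assumptions, not part of this diff; FastAPI embeds the two Body parameters under their names):

import httpx

BASE = "http://127.0.0.1:9090/api/v1/batches"  # host/port are an assumption

graph = {"nodes": {}, "edges": []}  # placeholder; a real serialized Graph goes here
batch = {
    "runs": 1,
    "data": [[{"node_path": "1", "field_name": "prompt", "items": ["a", "b"]}]],
}

created = httpx.post(f"{BASE}/", json={"graph": graph, "batch": batch}).json()
batch_id = created["batch_id"]

httpx.put(f"{BASE}/b/{batch_id}/invoke")                      # 202: batch started
sessions = httpx.get(f"{BASE}/b/{batch_id}/sessions").json()  # list of BatchSession
httpx.delete(f"{BASE}/b/{batch_id}")                          # 202: batch canceled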
@@ -9,13 +9,7 @@ from pydantic.fields import Field
 # Importing * is bad karma but needed here for node detection
 from ...invocations import *  # noqa: F401 F403
 from ...invocations.baseinvocation import BaseInvocation
-from ...services.graph import (
-    Edge,
-    EdgeConnection,
-    Graph,
-    GraphExecutionState,
-    NodeAlreadyExecutedError,
-)
+from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
 from ...services.item_storage import PaginatedResults
 from ..dependencies import ApiDependencies

@@ -13,11 +13,15 @@ class SocketIO:

     def __init__(self, app: FastAPI):
         self.__sio = SocketManager(app=app)
-        self.__sio.on("subscribe", handler=self._handle_sub)
-        self.__sio.on("unsubscribe", handler=self._handle_unsub)

+        self.__sio.on("subscribe_session", handler=self._handle_sub_session)
+        self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
         local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)

+        self.__sio.on("subscribe_batch", handler=self._handle_sub_batch)
+        self.__sio.on("unsubscribe_batch", handler=self._handle_unsub_batch)
+        local_handler.register(event_name=EventServiceBase.batch_event, _func=self._handle_batch_event)
+
     async def _handle_session_event(self, event: Event):
         await self.__sio.emit(
             event=event[1]["event"],

@@ -25,12 +29,25 @@ class SocketIO:
             room=event[1]["data"]["graph_execution_state_id"],
         )

-    async def _handle_sub(self, sid, data, *args, **kwargs):
+    async def _handle_sub_session(self, sid, data, *args, **kwargs):
         if "session" in data:
             self.__sio.enter_room(sid, data["session"])

-    # @app.sio.on('unsubscribe')
-
-    async def _handle_unsub(self, sid, data, *args, **kwargs):
+    async def _handle_unsub_session(self, sid, data, *args, **kwargs):
         if "session" in data:
             self.__sio.leave_room(sid, data["session"])
+
+    async def _handle_batch_event(self, event: Event):
+        await self.__sio.emit(
+            event=event[1]["event"],
+            data=event[1]["data"],
+            room=event[1]["data"]["batch_id"],
+        )
+
+    async def _handle_sub_batch(self, sid, data, *args, **kwargs):
+        if "batch_id" in data:
+            self.__sio.enter_room(sid, data["batch_id"])
+
+    async def _handle_unsub_batch(self, sid, data, *args, **kwargs):
+        if "batch_id" in data:
+            self.__sio.leave_room(sid, data["batch_id"])
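A client joins a batch room and then receives batch events addressed to that batch_id. A sketch using the python-socketio client (the library choice, host, and port are assumptions; only the event names come from the handlers above):

import socketio

sio = socketio.Client()

@sio.on("batch_session_created")
def on_batch_session_created(data):
    # data carries batch_id / graph_execution_state_id per the event service
    print("new session for batch", data["batch_id"])

sio.connect("http://127.0.0.1:9090")  # host/port are an assumption
sio.emit("subscribe_batch", {"batch_id": "<your-batch-id>"})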
@@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir
 import mimetypes

 from .api.dependencies import ApiDependencies
-from .api.routers import sessions, models, images, boards, board_images, app_info
+from .api.routers import sessions, batches, models, images, boards, board_images, app_info
 from .api.sockets import SocketIO
 from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase

@@ -90,6 +90,8 @@ async def shutdown_event():

 app.include_router(sessions.session_router, prefix="/api")

+app.include_router(batches.batches_router, prefix="/api")
+
 app.include_router(models.models_router, prefix="/api")

 app.include_router(images.images_router, prefix="/api")
@@ -5,6 +5,7 @@ import re
 import shlex
 import sys
 import time
+import sqlite3
 from typing import Union, get_type_hints, Optional

 from pydantic import BaseModel, ValidationError

@@ -29,6 +30,8 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
 from invokeai.app.services.images import ImageService, ImageServiceDependencies
 from invokeai.app.services.resource_name import SimpleNameService
 from invokeai.app.services.urls import LocalUrlService
+from invokeai.app.services.batch_manager import BatchManager
+from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
 from invokeai.app.services.invocation_stats import InvocationStatsService
 from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

@@ -252,19 +255,18 @@ def invoke_cli():
     db_location = config.db_path
     db_location.parent.mkdir(parents=True, exist_ok=True)
+    db_conn = sqlite3.connect(db_location, check_same_thread=False)  # TODO: figure out a better threading solution
     logger.info(f'InvokeAI database location is "{db_location}"')

-    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
-        filename=db_location, table_name="graph_executions"
-    )
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")

     urls = LocalUrlService()
-    image_record_storage = SqliteImageRecordStorage(db_location)
+    image_record_storage = SqliteImageRecordStorage(conn=db_conn)
     image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
     names = SimpleNameService()

-    board_record_storage = SqliteBoardRecordStorage(db_location)
-    board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
+    board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
+    board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)

     boards = BoardService(
         services=BoardServiceDependencies(

@@ -298,15 +300,19 @@ def invoke_cli():
         )
     )

+    batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
+    batch_manager = BatchManager(batch_manager_storage)
+
     services = InvocationServices(
         model_manager=model_manager,
         events=events,
         latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
         images=images,
         boards=boards,
+        batch_manager=batch_manager,
         board_images=board_images,
         queue=MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
+        graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
         graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
         performance_statistics=InvocationStatsService(graph_execution_manager),
invokeai/app/services/batch_manager.py (new file, 215 lines)

@@ -0,0 +1,215 @@
from abc import ABC, abstractmethod
from itertools import product
from typing import Optional
from uuid import uuid4

from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from pydantic import BaseModel, Field

from invokeai.app.services.batch_manager_storage import (
    Batch,
    BatchProcess,
    BatchProcessStorageBase,
    BatchSession,
    BatchSessionChanges,
    BatchSessionNotFoundException,
)
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker


class BatchProcessResponse(BaseModel):
    batch_id: str = Field(description="ID for the batch")
    session_ids: list[str] = Field(description="List of session IDs created for this batch")


class BatchManagerBase(ABC):
    @abstractmethod
    def start(self, invoker: Invoker) -> None:
        """Starts the BatchManager service"""
        pass

    @abstractmethod
    def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
        """Creates a batch process"""
        pass

    @abstractmethod
    def run_batch_process(self, batch_id: str) -> None:
        """Runs a batch process"""
        pass

    @abstractmethod
    def cancel_batch_process(self, batch_process_id: str) -> None:
        """Cancels a batch process"""
        pass

    @abstractmethod
    def get_batch(self, batch_id: str) -> BatchProcessResponse:
        """Gets a batch process"""
        pass

    @abstractmethod
    def get_batch_processes(self) -> list[BatchProcessResponse]:
        """Gets all batch processes"""
        pass

    @abstractmethod
    def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
        """Gets all incomplete batch processes"""
        pass

    @abstractmethod
    def get_sessions(self, batch_id: str) -> list[BatchSession]:
        """Gets the sessions associated with a batch"""
        pass


class BatchManager(BatchManagerBase):
    """Responsible for managing currently running and scheduled batch jobs"""

    __invoker: Invoker
    __batch_process_storage: BatchProcessStorageBase

    def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
        super().__init__()
        self.__batch_process_storage = batch_process_storage

    def start(self, invoker: Invoker) -> None:
        self.__invoker = invoker
        local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)

    async def on_event(self, event: Event):
        event_name = event[1]["event"]

        match event_name:
            case "graph_execution_state_complete":
                await self._process(event, False)
            case "invocation_error":
                await self._process(event, True)

        return event

    async def _process(self, event: Event, err: bool) -> None:
        data = event[1]["data"]
        try:
            batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
        except BatchSessionNotFoundException:
            return None
        changes = BatchSessionChanges(state="error" if err else "completed")
        batch_session = self.__batch_process_storage.update_session_state(
            batch_session.batch_id,
            batch_session.session_id,
            changes,
        )
        sessions = self.get_sessions(batch_session.batch_id)
        batch_process = self.__batch_process_storage.get(batch_session.batch_id)
        if not batch_process.canceled:
            self.run_batch_process(batch_process.batch_id)

    def _create_graph_execution_state(
        self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
    ) -> GraphExecutionState:
        graph = batch_process.graph.copy(deep=True)
        batch = batch_process.batch
        for index, bdl in enumerate(batch.data):
            for bd in bdl:
                node = graph.get_node(bd.node_path)
                if node is None:
                    continue
                batch_index = batch_indices[index]
                datum = bd.items[batch_index]
                key = bd.field_name
                node.__dict__[key] = datum
                graph.update_node(bd.node_path, node)

        return GraphExecutionState(graph=graph)

    def run_batch_process(self, batch_id: str) -> None:
        self.__batch_process_storage.start(batch_id)
        batch_process = self.__batch_process_storage.get(batch_id)
        next_batch_index = self._get_batch_index_tuple(batch_process)
        if next_batch_index is None:
            # finished with current run
            if batch_process.current_run >= (batch_process.batch.runs - 1):
                # finished with all runs
                return
            batch_process.current_batch_index = 0
            batch_process.current_run += 1
            next_batch_index = self._get_batch_index_tuple(batch_process)
            if next_batch_index is None:
                # shouldn't happen; satisfy types
                return
        # remember to increment the batch index
        batch_process.current_batch_index += 1
        self.__batch_process_storage.save(batch_process)
        ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
        next_session = self.__batch_process_storage.create_session(
            BatchSession(
                batch_id=batch_id,
                session_id=str(uuid4()),
                state="uninitialized",
                batch_index=batch_process.current_batch_index,
            )
        )
        ges.id = next_session.session_id
        self.__invoker.services.graph_execution_manager.set(ges)
        self.__batch_process_storage.update_session_state(
            batch_id=next_session.batch_id,
            session_id=next_session.session_id,
            changes=BatchSessionChanges(state="in_progress"),
        )
        self.__invoker.services.events.emit_batch_session_created(next_session.batch_id, next_session.session_id)
        self.__invoker.invoke(ges, invoke_all=True)

    def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
        batch_process = BatchProcess(
            batch=batch,
            graph=graph,
        )
        batch_process = self.__batch_process_storage.save(batch_process)
        return BatchProcessResponse(
            batch_id=batch_process.batch_id,
            session_ids=[],
        )

    def get_sessions(self, batch_id: str) -> list[BatchSession]:
        return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)

    def get_batch(self, batch_id: str) -> BatchProcess:
        return self.__batch_process_storage.get(batch_id)

    def get_batch_processes(self) -> list[BatchProcessResponse]:
        bps = self.__batch_process_storage.get_all()
        return self._get_batch_process_responses(bps)

    def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
        bps = self.__batch_process_storage.get_incomplete()
        return self._get_batch_process_responses(bps)

    def cancel_batch_process(self, batch_process_id: str) -> None:
        self.__batch_process_storage.cancel(batch_process_id)

    def _get_batch_process_responses(self, batch_processes: list[BatchProcess]) -> list[BatchProcessResponse]:
        sessions = list()
        res: list[BatchProcessResponse] = list()
        for bp in batch_processes:
            sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
            res.append(
                BatchProcessResponse(
                    batch_id=bp.batch_id,
                    session_ids=[session.session_id for session in sessions],
                )
            )
        return res

    def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
        batch_indices = list()
        for batchdata in batch_process.batch.data:
            batch_indices.append(list(range(len(batchdata[0].items))))
        try:
            return list(product(*batch_indices))[batch_process.current_batch_index]
        except IndexError:
            return None
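run_batch_process and _get_batch_index_tuple together treat current_batch_index as a cursor into the Cartesian product of the zipped collections' index ranges. A standalone sketch of that enumeration (illustrative sizes, not InvokeAI code):

from itertools import product

# Two zipped collections: 3 items in the first, 2 in the second.
collection_lens = [3, 2]
index_tuples = list(product(*(range(n) for n in collection_lens)))
# -> [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]: 6 sessions per run

cursor = 4  # plays the role of batch_process.current_batch_index
print(index_tuples[cursor])  # (2, 0): item 2 of collection 0, item 0 of collection 1
# Indexing past the end raises IndexError, which is how the manager
# detects "finished with current run" and either bumps current_run or stops.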
invokeai/app/services/batch_manager_storage.py (new file, 707 lines)

@@ -0,0 +1,707 @@
import sqlite3
import threading
import uuid
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Union, cast

from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator

from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.graph import Graph

BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]


class BatchData(BaseModel):
    """
    A batch data collection.
    """

    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
    field_name: str = Field(description="The field into which this batch data collection will be substituted.")
    items: list[BatchDataType] = Field(
        default_factory=list, description="The list of items to substitute into the node/field."
    )


class Batch(BaseModel):
    """
    A batch, consisting of a list of a list of batch data collections.

    First, each inner list[BatchData] is zipped into a single batch data collection.

    Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
    """

    data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
    runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")

    @validator("runs")
    def validate_positive_runs(cls, r: int):
        if r < 1:
            raise ValueError("runs must be a positive integer")
        return r

    @validator("data")
    def validate_len(cls, v: list[list[BatchData]]):
        for batch_data in v:
            if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
                raise ValueError("Zipped batch items must all have the same length")
        return v

    @validator("data")
    def validate_types(cls, v: list[list[BatchData]]):
        for batch_data in v:
            for datum in batch_data:
                for item in datum.items:
                    if not all(isinstance(item, type(i)) for i in datum.items):
                        raise TypeError("All items in a batch must have the same type")
        return v

    @validator("data")
    def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
        paths: set[tuple[str, str]] = set()
        count: int = 0
        for batch_data in v:
            for datum in batch_data:
                paths.add((datum.node_path, datum.field_name))
                count += 1
        if len(paths) != count:
            raise ValueError("Each batch data must have a unique node_path and field_name")
        return v


def uuid_string():
    res = uuid.uuid4()
    return str(res)


BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]


class BatchSession(BaseModel):
    batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
    session_id: str = Field(
        default_factory=uuid_string, description="The Session to which this BatchSession is attached."
    )
    batch_index: int = Field(description="The index of this batch session in its parent batch process")
    state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")


class BatchProcess(BaseModel):
    batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
    batch: Batch = Field(description="The Batch to apply to this session.")
    current_batch_index: int = Field(default=0, description="The last executed batch index")
    current_run: int = Field(default=0, description="The current run of the batch")
    canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
    graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")


class BatchSessionChanges(BaseModel, extra=Extra.forbid):
    state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")


class BatchProcessNotFoundException(Exception):
    """Raised when a Batch Process record is not found."""

    def __init__(self, message="BatchProcess record not found"):
        super().__init__(message)


class BatchProcessSaveException(Exception):
    """Raised when a Batch Process record cannot be saved."""

    def __init__(self, message="BatchProcess record not saved"):
        super().__init__(message)


class BatchProcessDeleteException(Exception):
    """Raised when a Batch Process record cannot be deleted."""

    def __init__(self, message="BatchProcess record not deleted"):
        super().__init__(message)


class BatchSessionNotFoundException(Exception):
    """Raised when a Batch Session record is not found."""

    def __init__(self, message="BatchSession record not found"):
        super().__init__(message)


class BatchSessionSaveException(Exception):
    """Raised when a Batch Session record cannot be saved."""

    def __init__(self, message="BatchSession record not saved"):
        super().__init__(message)


class BatchSessionDeleteException(Exception):
    """Raised when a Batch Session record cannot be deleted."""

    def __init__(self, message="BatchSession record not deleted"):
        super().__init__(message)


class BatchProcessStorageBase(ABC):
    """Low-level service responsible for interfacing with the Batch Process record store."""

    @abstractmethod
    def delete(self, batch_id: str) -> None:
        """Deletes a BatchProcess record."""
        pass

    @abstractmethod
    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Saves a BatchProcess record."""
        pass

    @abstractmethod
    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Gets a BatchProcess record."""
        pass

    @abstractmethod
    def get_all(
        self,
    ) -> list[BatchProcess]:
        """Gets all BatchProcess records."""
        pass

    @abstractmethod
    def get_incomplete(
        self,
    ) -> list[BatchProcess]:
        """Gets all incomplete BatchProcess records."""
        pass

    @abstractmethod
    def start(
        self,
        batch_id: str,
    ) -> None:
        """'Starts' a BatchProcess record by setting its `canceled` attribute to False."""
        pass

    @abstractmethod
    def cancel(
        self,
        batch_id: str,
    ) -> None:
        """'Cancels' a BatchProcess record by setting its `canceled` attribute to True."""
        pass

    @abstractmethod
    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Creates a BatchSession attached to a BatchProcess."""
        pass

    @abstractmethod
    def create_sessions(
        self,
        sessions: list[BatchSession],
    ) -> list[BatchSession]:
        """Creates many BatchSessions attached to a BatchProcess."""
        pass

    @abstractmethod
    def get_session_by_session_id(self, session_id: str) -> BatchSession:
        """Gets a BatchSession by session_id"""
        pass

    @abstractmethod
    def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
        """Gets all BatchSessions for a given BatchProcess id."""
        pass

    @abstractmethod
    def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
        """Gets all BatchSessions for a given list of session ids."""
        pass

    @abstractmethod
    def get_next_session(self, batch_id: str) -> BatchSession:
        """Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
        pass

    @abstractmethod
    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Updates the state of a BatchSession record."""
        pass


class SqliteBatchProcessStorage(BatchProcessStorageBase):
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, conn: sqlite3.Connection) -> None:
        super().__init__()
        self._conn = conn
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        try:
            self._lock.acquire()
            # Enable foreign keys
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the `batch_process` table and `batch_session` junction table."""

        # Create the `batch_process` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_process (
                batch_id TEXT NOT NULL PRIMARY KEY,
                batch TEXT NOT NULL,
                graph TEXT NOT NULL,
                current_batch_index NUMBER NOT NULL,
                current_run NUMBER NOT NULL,
                canceled BOOLEAN NOT NULL DEFAULT(0),
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
            AFTER UPDATE
            ON batch_process FOR EACH ROW
            BEGIN
                UPDATE batch_process SET updated_at = current_timestamp
                    WHERE batch_id = old.batch_id;
            END;
            """
        )

        # Create the `batch_session` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_session (
                batch_id TEXT NOT NULL,
                session_id TEXT NOT NULL,
                state TEXT NOT NULL,
                batch_index NUMBER NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between batch_process and batch_session using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (batch_id,session_id),
                FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
            );
            """
        )

        # Add index for batch id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
            """
        )

        # Add index for batch id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
            AFTER UPDATE
            ON batch_session FOR EACH ROW
            BEGIN
                UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
            END;
            """
        )

    def delete(self, batch_id: str) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                DELETE FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessDeleteException from e
        except Exception as e:
            self._conn.rollback()
            raise BatchProcessDeleteException from e
        finally:
            self._lock.release()

    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
                VALUES (?, ?, ?, ?, ?);
                """,
                (
                    batch_process.batch_id,
                    batch_process.batch.json(),
                    batch_process.graph.json(),
                    batch_process.current_batch_index,
                    batch_process.current_run,
                ),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessSaveException from e
        finally:
            self._lock.release()
        return self.get(batch_process.batch_id)

    def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
        """Deserializes a batch session."""

        # Retrieve all the values, setting "reasonable" defaults if they are not present.

        batch_id = session_dict.get("batch_id", "unknown")
        batch_raw = session_dict.get("batch", "unknown")
        graph_raw = session_dict.get("graph", "unknown")
        current_batch_index = session_dict.get("current_batch_index", 0)
        current_run = session_dict.get("current_run", 0)
        canceled = session_dict.get("canceled", 0)
        return BatchProcess(
            batch_id=batch_id,
            batch=parse_raw_as(Batch, batch_raw),
            graph=parse_raw_as(Graph, graph_raw),
            current_batch_index=current_batch_index,
            current_run=current_run,
            canceled=canceled == 1,
        )

    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )

            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchProcessNotFoundException
        return self._deserialize_batch_process(dict(result))

    def get_all(
        self,
    ) -> list[BatchProcess]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_process
                """
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            return list()
        return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))

    def get_incomplete(
        self,
    ) -> list[BatchProcess]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT bp.*
                FROM batch_process bp
                WHERE bp.batch_id IN
                (
                    SELECT batch_id
                    FROM batch_session bs
                    WHERE state IN ('uninitialized', 'in_progress')
                );
                """
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            return list()
        return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))

    def start(
        self,
        batch_id: str,
    ) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                UPDATE batch_process
                SET canceled = 0
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()

    def cancel(
        self,
        batch_id: str,
    ) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                UPDATE batch_process
                SET canceled = 1
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()

    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                VALUES (?, ?, ?, ?);
                """,
                (session.batch_id, session.session_id, session.state, session.batch_index),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session_by_session_id(session.session_id)

    def create_sessions(
        self,
        sessions: list[BatchSession],
    ) -> list[BatchSession]:
        try:
            self._lock.acquire()
            session_data = [
                (session.batch_id, session.session_id, session.state, session.batch_index) for session in sessions
            ]
            self._cursor.executemany(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                VALUES (?, ?, ?, ?);
                """,
                session_data,
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_sessions_by_session_ids([session.session_id for session in sessions])

    def get_session_by_session_id(self, session_id: str) -> BatchSession:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE session_id = ?;
                """,
                (session_id,),
            )

            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        return self._deserialize_batch_session(dict(result))

    def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
        """Deserializes a batch session."""

        return BatchSession.parse_obj(session_dict)

    def get_next_session(self, batch_id: str) -> BatchSession:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'uninitialized';
                """,
                (batch_id,),
            )

            result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        session = self._deserialize_batch_session(dict(result))
        return session

    def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
        return sessions

    def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
        try:
            self._lock.acquire()
            placeholders = ",".join("?" * len(session_ids))
            self._cursor.execute(
                f"""--sql
                SELECT * FROM batch_session
                WHERE session_id
                IN ({placeholders})
                """,
                tuple(session_ids),
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
        return sessions

    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        try:
            self._lock.acquire()

            # Change the state of a batch session
            if changes.state is not None:
                self._cursor.execute(
                    """--sql
                    UPDATE batch_session
                    SET state = ?
                    WHERE batch_id = ? AND session_id = ?;
                    """,
                    (changes.state, batch_id, session_id),
                )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session_by_session_id(session_id)
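The updated_at triggers in _create_tables above can be exercised in isolation. A trimmed sketch (only the columns needed for the demo; the full schema is in the DDL above; SQLite does not recurse into the trigger's own UPDATE because recursive triggers are off by default):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE batch_process (
        batch_id   TEXT NOT NULL PRIMARY KEY,
        canceled   BOOLEAN NOT NULL DEFAULT(0),
        updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
    );
    CREATE TRIGGER tg_batch_process_updated_at AFTER UPDATE ON batch_process FOR EACH ROW
    BEGIN
        UPDATE batch_process SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE batch_id = old.batch_id;
    END;
    """
)
conn.execute("INSERT INTO batch_process (batch_id) VALUES ('b1');")
conn.execute("UPDATE batch_process SET canceled = 1 WHERE batch_id = 'b1';")  # fires the trigger
print(conn.execute("SELECT updated_at FROM batch_process;").fetchone())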
@@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC):


 class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
-    _filename: str
     _conn: sqlite3.Connection
     _cursor: sqlite3.Cursor
     _lock: threading.Lock

-    def __init__(self, filename: str) -> None:
+    def __init__(self, conn: sqlite3.Connection) -> None:
         super().__init__()
-        self._filename = filename
-        self._conn = sqlite3.connect(filename, check_same_thread=False)
+        self._conn = conn
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
@@ -89,15 +89,13 @@ class BoardRecordStorageBase(ABC):


 class SqliteBoardRecordStorage(BoardRecordStorageBase):
-    _filename: str
     _conn: sqlite3.Connection
     _cursor: sqlite3.Cursor
     _lock: threading.Lock

-    def __init__(self, filename: str) -> None:
+    def __init__(self, conn: sqlite3.Connection) -> None:
         super().__init__()
-        self._filename = filename
-        self._conn = sqlite3.connect(filename, check_same_thread=False)
+        self._conn = conn
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
@@ -13,6 +13,7 @@ from invokeai.app.services.model_manager_service import (

 class EventServiceBase:
     session_event: str = "session_event"
+    batch_event: str = "batch_event"

     """Basic event bus, to have an empty stand-in when not needed"""

@@ -20,12 +21,21 @@ class EventServiceBase:
         pass

     def __emit_session_event(self, event_name: str, payload: dict) -> None:
+        """Session events are emitted to a room with the session_id as the room name"""
         payload["timestamp"] = get_timestamp()
         self.dispatch(
             event_name=EventServiceBase.session_event,
             payload=dict(event=event_name, data=payload),
         )

+    def __emit_batch_event(self, event_name: str, payload: dict) -> None:
+        """Batch events are emitted to a room with the batch_id as the room name"""
+        payload["timestamp"] = get_timestamp()
+        self.dispatch(
+            event_name=EventServiceBase.batch_event,
+            payload=dict(event=event_name, data=payload),
+        )
+
     # Define events here for every event in the system.
     # This will make them easier to integrate until we find a schema generator.
     def emit_generator_progress(

@@ -187,3 +197,14 @@ class EventServiceBase:
                 error=error,
             ),
         )

+    def emit_batch_session_created(
+        self,
+        batch_id: str,
+        graph_execution_state_id: str,
+    ) -> None:
+        """Emitted when a batch session is created"""
+        self.__emit_batch_event(
+            event_name="batch_session_created",
+            payload=dict(batch_id=batch_id, graph_execution_state_id=graph_execution_state_id),
+        )
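For reference, a batch event as produced by emit_batch_session_created reaches dispatch with this shape (a sketch; the IDs and timestamp values are illustrative):

# dispatch(event_name="batch_event", payload=...) receives:
payload = {
    "event": "batch_session_created",
    "data": {
        "batch_id": "4ad3...",                  # illustrative
        "graph_execution_state_id": "91c0...",  # illustrative
        "timestamp": 1692223337,                # added by get_timestamp()
    },
}
# The socket layer then emits event "batch_session_created" to the
# socket.io room named payload["data"]["batch_id"].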
@@ -152,15 +152,13 @@ class ImageRecordStorageBase(ABC):


 class SqliteImageRecordStorage(ImageRecordStorageBase):
-    _filename: str
     _conn: sqlite3.Connection
     _cursor: sqlite3.Cursor
     _lock: threading.Lock

-    def __init__(self, filename: str) -> None:
+    def __init__(self, conn: sqlite3.Connection) -> None:
         super().__init__()
-        self._filename = filename
-        self._conn = sqlite3.connect(filename, check_same_thread=False)
+        self._conn = conn
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from logging import Logger
+    from invokeai.app.services.batch_manager import BatchManagerBase
     from invokeai.app.services.board_images import BoardImagesServiceABC
     from invokeai.app.services.boards import BoardServiceABC
     from invokeai.app.services.images import ImageServiceABC

@@ -22,6 +23,7 @@ class InvocationServices:
     """Services that can be used by invocations"""

     # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
+    batch_manager: "BatchManagerBase"
     board_images: "BoardImagesServiceABC"
     boards: "BoardServiceABC"
     configuration: "InvokeAIAppConfig"

@@ -38,6 +40,7 @@ class InvocationServices:

     def __init__(
         self,
+        batch_manager: "BatchManagerBase",
         board_images: "BoardImagesServiceABC",
         boards: "BoardServiceABC",
         configuration: "InvokeAIAppConfig",

@@ -52,6 +55,7 @@ class InvocationServices:
         performance_statistics: "InvocationStatsServiceBase",
         queue: "InvocationQueueABC",
     ):
+        self.batch_manager = batch_manager
         self.board_images = board_images
         self.boards = boards
@@ -12,23 +12,19 @@ sqlite_memory = ":memory:"


 class SqliteItemStorage(ItemStorageABC, Generic[T]):
-    _filename: str
     _table_name: str
     _conn: sqlite3.Connection
     _cursor: sqlite3.Cursor
     _id_field: str
     _lock: Lock

-    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
+    def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
         super().__init__()

-        self._filename = filename
         self._table_name = table_name
         self._id_field = id_field  # TODO: validate that T has this field
         self._lock = Lock()
-        self._conn = sqlite3.connect(
-            self._filename, check_same_thread=False
-        )  # TODO: figure out a better threading solution
+        self._conn = conn
         self._cursor = self._conn.cursor()

         self._create_table()

@@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):

     def _parse_item(self, item: str) -> T:
         item_type = get_args(self.__orig_class__)[0]
-        parsed = parse_raw_as(item_type, item)
-        return parsed
+        return parse_raw_as(item_type, item)

     def set(self, item: T):
         try:
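Callers now construct the connection themselves and hand it in; usage after this change looks like the following (the type parameter and table name match their use elsewhere in this diff):

import sqlite3

from invokeai.app.services.graph import GraphExecutionState
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory

# The caller owns the connection and can share it with other services.
conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
storage = SqliteItemStorage[GraphExecutionState](conn=conn, table_name="graph_executions")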
invokeai/frontend/web/src/services/api/schema.d.ts (vendored, 344 lines changed)

File diff suppressed because one or more lines are too long
@@ -102,6 +102,7 @@ dependencies = [
     "flake8",
     "Flake8-pyproject",
     "pytest>6.0.0",
+    "pytest-asyncio",
     "pytest-cov",
     "pytest-datadir",
 ]

@@ -176,6 +177,7 @@ version = { attr = "invokeai.version.__version__" }
 #=== Begin: PyTest and Coverage
 [tool.pytest.ini_options]
 addopts = "--cov-report term --cov-report html --cov-report xml"
+asyncio_mode = "auto"
 [tool.coverage.run]
 branch = true
 source = ["invokeai"]
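With asyncio_mode = "auto", pytest-asyncio collects async def tests without per-test markers, and the addopts line makes every run emit terminal, HTML, and XML coverage reports. An equivalent invocation from Python (the "tests" path is an assumption about the repo layout):

import pytest

# Same as running `pytest tests` from the repo root; the coverage flags
# are picked up automatically from addopts in pyproject.toml.
pytest.main(["tests"])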
@@ -25,6 +25,7 @@ from invokeai.app.services.graph import (
     LibraryGraph,
 )
 import pytest
+import sqlite3


 @pytest.fixture

@@ -42,9 +43,8 @@ def simple_graph():
 @pytest.fixture
 def mock_services() -> InvocationServices:
     # NOTE: none of these are actually called by the test invocations
-    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
-        filename=sqlite_memory, table_name="graph_executions"
-    )
+    db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
     return InvocationServices(
         model_manager=None,  # type: ignore
         events=TestEventService(),

@@ -52,9 +52,10 @@ def mock_services() -> InvocationServices:
         images=None,  # type: ignore
         latents=None,  # type: ignore
         boards=None,  # type: ignore
+        batch_manager=None,  # type: ignore
         board_images=None,  # type: ignore
         queue=MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
+        graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
         graph_execution_manager=graph_execution_manager,
         performance_statistics=InvocationStatsService(graph_execution_manager),
         processor=DefaultInvocationProcessor(),
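A subtlety behind the conn=db_conn change in these fixtures: every sqlite3.connect(":memory:") call opens a separate, empty database, so two services only see the same tables if they share the connection object itself, which is what the session-scoped db_conn fixture in the next file provides. A standalone illustration:

import sqlite3

a = sqlite3.connect(":memory:")
b = sqlite3.connect(":memory:")  # a *different* in-memory database

a.execute("CREATE TABLE t (x);")
a.execute("INSERT INTO t VALUES (1);")
print(a.execute("SELECT COUNT(*) FROM t;").fetchone())          # (1,)
print(b.execute("SELECT name FROM sqlite_master;").fetchall())  # [] -- no table t here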
@ -6,18 +6,34 @@ from .test_nodes import (
    create_edge,
    wait_until,
)
# from fastapi_events.handlers.local import
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.api.events import FastAPIEventService
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
from invokeai.app.services.batch_manager import (
    Batch,
    BatchManager,
)
from invokeai.app.services.graph import (
    Graph,
    GraphExecutionState,
    GraphInvocation,
    LibraryGraph,
)
import pytest
import pytest_asyncio
import sqlite3
import time

from httpx import AsyncClient
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware


@pytest.fixture
@ -29,25 +45,128 @@ def simple_graph():
    return g


@pytest.fixture
def simple_batch():
    return Batch(
        data=[
            [
                BatchData(
                    node_path="1",
                    field_name="prompt",
                    items=[
                        "Tomato sushi",
                        "Strawberry sushi",
                        "Broccoli sushi",
                        "Asparagus sushi",
                        "Tea sushi",
                    ],
                )
            ],
            [
                BatchData(
                    node_path="2",
                    field_name="prompt",
                    items=[
                        "Ume sushi",
                        "Ichigo sushi",
                        "Momo sushi",
                        "Mikan sushi",
                        "Cha sushi",
                    ],
                )
            ],
        ]
    )

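Taken together with the assertions further down (len(sessions) == 25) and the validation test at the end of this file, the Batch semantics appear to be: BatchData entries within one group are zipped, so their item lists must be equally long, while separate groups are expanded as a Cartesian product. A small sketch of that arithmetic, with names invented for illustration:

    # simple_batch has two groups of five prompts each: 5 * 5 = 25 sessions.
    group_sizes = [5, 5]  # item-list length per group

    expected_sessions = 1
    for size in group_sizes:
        expected_sessions *= size

    assert expected_sessions == 25
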
@pytest.fixture
def graph_with_subgraph():
    sub_g = Graph()
    sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
    sub_g.add_node(TextToImageTestInvocation(id="2"))
    sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
    g = Graph()
    g.add_node(GraphInvocation(id="1", graph=sub_g))
    return g


@pytest.fixture
def batch_with_subgraph():
    return Batch(
        data=[
            [
                BatchData(
                    node_path="1.1",
                    field_name="prompt",
                    items=[
                        "Tomato sushi",
                        "Strawberry sushi",
                        "Broccoli sushi",
                        "Asparagus sushi",
                        "Tea sushi",
                    ],
                )
            ],
            [
                BatchData(
                    node_path="1.2",
                    field_name="prompt",
                    items=[
                        "Ume sushi",
                        "Ichigo sushi",
                        "Momo sushi",
                        "Mikan sushi",
                        "Cha sushi",
                    ],
                )
            ],
        ]
    )

# @pytest_asyncio.fixture(scope="module")
# def event_loop():
#     import asyncio
#     try:
#         loop = asyncio.get_running_loop()
#     except RuntimeError as e:
#         loop = asyncio.new_event_loop()
#     asyncio.set_event_loop(loop)
#     # fastapi_events.event_store
#     yield loop
#     loop.close()

@pytest.fixture(scope="session")
def db_conn():
    return sqlite3.connect(sqlite_memory, check_same_thread=False)

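The session scope on db_conn is significant because sqlite_memory is presumably the ":memory:" path: every new connection to ":memory:" opens a fresh, empty database, so services only see each other's tables when they share a single connection. A short demonstration of the pitfall:

    import sqlite3

    # Two separate in-memory connections are two separate databases.
    a = sqlite3.connect(":memory:")
    b = sqlite3.connect(":memory:")
    a.execute("CREATE TABLE t (x INTEGER)")
    # Querying t on 'b' raises sqlite3.OperationalError: no such table: t
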
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
    # NOTE: none of these are actually called by the test invocations
    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
        filename=sqlite_memory, table_name="graph_executions"
@pytest.fixture(autouse=True)
async def mock_services(db_conn: sqlite3.Connection) -> InvocationServices:
    app = FastAPI()
    event_handler_id: int = id(app)
    app.add_middleware(
        EventHandlerASGIMiddleware,
        handlers=[local_handler],  # TODO: consider doing this in services to support different configurations
        middleware_id=event_handler_id,
    )
    client = AsyncClient(app=app)
    # NOTE: none of these are actually called by the test invocations
    graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
    batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
    events = FastAPIEventService(event_handler_id)
    return InvocationServices(
        model_manager=None,  # type: ignore
        events=TestEventService(),
        events=events,
        logger=None,  # type: ignore
        images=None,  # type: ignore
        latents=None,  # type: ignore
        batch_manager=BatchManager(batch_manager_storage),
        boards=None,  # type: ignore
        board_images=None,  # type: ignore
        queue=MemoryInvocationQueue(),
        graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
        graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
        graph_execution_manager=graph_execution_manager,
        processor=DefaultInvocationProcessor(),
        performance_statistics=InvocationStatsService(graph_execution_manager),

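The rewritten fixture swaps the TestEventService for a real FastAPIEventService behind the fastapi_events middleware. The TODO comments in the batch tests below explain why: the batch manager only advances its sessions when graph_execution_state events arrive, so a mock that never emits them would leave a batch process pending forever.
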
@ -130,3 +249,135 @@ def test_handles_errors(mock_invoker: Invoker):
    assert g.is_complete()

    assert all((i in g.errors for i in g.source_prepared_mapping["1"]))


def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
    batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
        batch=batch_with_subgraph,
        graph=graph_with_subgraph,
    )
    assert batch_process_res.batch_id
    # TODO: without the mock events service emitting the `graph_execution_state` events,
    # the batch sessions do not know when they have finished, so this logic will fail

    # assert len(batch_process_res.session_ids) == 25
    # mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)

    # def has_executed_all_batches(batch_id: str):
    #     batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
    #     print(batch_sessions)
    #     return all((s.state == "completed" for s in batch_sessions))

    # wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)

async def test_can_run_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
    batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
        batch=batch_with_subgraph,
        graph=graph_with_subgraph,
    )
    mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
    sessions = []
    attempts = 0
    import asyncio

    while len(sessions) != 25 and attempts < 20:
        batch = mock_invoker.services.batch_manager.get_batch(batch_process_res.batch_id)
        sessions = mock_invoker.services.batch_manager.get_sessions(batch_process_res.batch_id)
        await asyncio.sleep(1)
        attempts += 1
    assert len(sessions) == 25

def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
    batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
        batch=simple_batch,
        graph=simple_graph,
    )
    assert batch_process_res.batch_id
    # TODO: without the mock events service emitting the `graph_execution_state` events,
    # the batch sessions do not know when they have finished, so this logic will fail

    # assert len(batch_process_res.session_ids) == 25
    # mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)

    # def has_executed_all_batches(batch_id: str):
    #     batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
    #     print(batch_sessions)
    #     return all((s.state == "completed" for s in batch_sessions))

    # wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)


def test_cannot_create_bad_batches():
    batch = None
    try:
        batch = Batch(  # This batch has a duplicate node_path|fieldname combo
            data=[
                [
                    BatchData(
                        node_path="1",
                        field_name="prompt",
                        items=[
                            "Tomato sushi",
                        ],
                    )
                ],
                [
                    BatchData(
                        node_path="1",
                        field_name="prompt",
                        items=[
                            "Ume sushi",
                        ],
                    )
                ],
            ]
        )
    except Exception as e:
        assert e
    try:
        batch = Batch(  # This batch has different item list lengths in the same group
            data=[
                [
                    BatchData(
                        node_path="1",
                        field_name="prompt",
                        items=[
                            "Tomato sushi",
                        ],
                    ),
                    BatchData(
                        node_path="1",
                        field_name="prompt",
                        items=[
                            "Tomato sushi",
                            "Courgette sushi",
                        ],
                    ),
                ],
                [
                    BatchData(
                        node_path="1",
                        field_name="prompt",
                        items=[
                            "Ume sushi",
                        ],
                    )
                ],
            ]
        )
    except Exception as e:
        assert e
    try:
        batch = Batch(  # This batch has a type mismatch in single items list
            data=[
                [
                    BatchData(
                        node_path="1",
                        field_name="prompt",
                        items=["Tomato sushi", 5],
                    ),
                ],
            ]
        )
    except Exception as e:
        assert e
    assert not batch
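The try/except/assert pattern above depends on the final assert not batch to catch a construction that unexpectedly succeeds. A tighter formulation of the first case, assuming the validators raise pydantic's ValidationError (the test itself only requires that some exception fires):

    import pytest
    from pydantic import ValidationError

    from invokeai.app.services.batch_manager import Batch
    from invokeai.app.services.batch_manager_storage import BatchData


    def test_rejects_duplicate_node_path_and_field():
        # The same node_path/field_name pair in two groups should fail validation.
        with pytest.raises(ValidationError):
            Batch(
                data=[
                    [BatchData(node_path="1", field_name="prompt", items=["Tomato sushi"])],
                    [BatchData(node_path="1", field_name="prompt", items=["Ume sushi"])],
                ]
            )
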
@ -51,6 +51,7 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
@invocation("test_text_to_image")
class TextToImageTestInvocation(BaseInvocation):
    prompt: str = Field(default="")
    prompt2: str = Field(default="")

    def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
        return ImageTestInvocationOutput(image=ImageField(image_name=self.id))

@ -1,20 +1,27 @@
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from pydantic import BaseModel, Field

import pytest
import sqlite3


class TestModel(BaseModel):
    id: str = Field(description="ID")
    name: str = Field(description="Name")


def test_sqlite_service_can_create_and_get():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
@pytest.fixture
def db() -> SqliteItemStorage[TestModel]:
    db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
    return SqliteItemStorage[TestModel](db_conn, "test", "id")


def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    assert db.get("1") == TestModel(id="1", name="Test")


def test_sqlite_service_can_list():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    db.set(TestModel(id="2", name="Test"))
    db.set(TestModel(id="3", name="Test"))
@ -30,15 +37,13 @@ def test_sqlite_service_can_list():
    ]


def test_sqlite_service_can_delete():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    db.delete("1")
    assert db.get("1") is None


def test_sqlite_service_calls_set_callback():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
    called = False

    def on_changed(item: TestModel):
@ -50,8 +55,7 @@ def test_sqlite_service_calls_set_callback():
    assert called


def test_sqlite_service_calls_delete_callback():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
    called = False

    def on_deleted(item_id: str):
@ -64,8 +68,7 @@ def test_sqlite_service_calls_delete_callback():
    assert called


def test_sqlite_service_can_list_with_pagination():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    db.set(TestModel(id="2", name="Test"))
    db.set(TestModel(id="3", name="Test"))
@ -77,8 +80,7 @@ def test_sqlite_service_can_list_with_pagination():
    assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]


def test_sqlite_service_can_list_with_pagination_and_offset():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    db.set(TestModel(id="2", name="Test"))
    db.set(TestModel(id="3", name="Test"))
@ -90,8 +92,7 @@ def test_sqlite_service_can_list_with_pagination_and_offset():
    assert results.items == [TestModel(id="3", name="Test")]


def test_sqlite_service_can_search():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    db.set(TestModel(id="2", name="Test"))
    db.set(TestModel(id="3", name="Test"))
@ -107,8 +108,7 @@ def test_sqlite_service_can_search():
    ]


def test_sqlite_service_can_search_with_pagination():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    db.set(TestModel(id="2", name="Test"))
    db.set(TestModel(id="3", name="Test"))
@ -120,8 +120,7 @@ def test_sqlite_service_can_search_with_pagination():
    assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]


def test_sqlite_service_can_search_with_pagination_and_offset():
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
    db.set(TestModel(id="1", name="Test"))
    db.set(TestModel(id="2", name="Test"))
    db.set(TestModel(id="3", name="Test"))

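The fixture refactor in this file centralizes construction: each test now receives a fresh connection to sqlite_memory (presumably ":memory:"), so every test operates on its own private copy of the tables and the repeated db = SqliteItemStorage[...] boilerplate disappears from the test bodies.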