mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
86 Commits
feat/docs_
...
feat/batch
Author | SHA1 | Date | |
---|---|---|---|
ed7deee8f1 | |||
d22c4734ee | |||
3e7dadd7b3 | |||
b777dba430 | |||
531c3bb1e2 | |||
331743ca0c | |||
13429e66b3 | |||
2185c85287 | |||
e8a4a654ac | |||
26f9ac9f21 | |||
8d78af5db7 | |||
babd26feab | |||
e9b26e5e7d | |||
6b946f53c4 | |||
70479b9827 | |||
35099dcdd8 | |||
670600a863 | |||
6d5403e19d | |||
0f7695a081 | |||
d567d9f804 | |||
68f6140685 | |||
be971617e3 | |||
1652143671 | |||
88ae19a768 | |||
50816432dc | |||
b98c9b516a | |||
a15a5bc3b8 | |||
018ff56314 | |||
137fbacb92 | |||
4b6d9a73ed | |||
3e26214b83 | |||
0282f46c71 | |||
99e03fe92e | |||
cb65526880 | |||
59bc9ed399 | |||
e62d5478fd | |||
2cf0d61b3e | |||
cc3c2756bd | |||
67cf594bb3 | |||
c5b963f1a6 | |||
4d2dd6bb10 | |||
7e4beab4ff | |||
e16b5f7cdc | |||
1f355d5810 | |||
df7370f9d9 | |||
5bec64d65b | |||
8cf9bd47b2 | |||
c91621b46c | |||
f246b236dd | |||
f7277a8b21 | |||
796ee1246b | |||
29fceb960d | |||
796ff34c8a | |||
d6a5c2dbe3 | |||
ef8dc2e8c5 | |||
314891a125 | |||
2d3094f988 | |||
abf09fc8fa | |||
15e7ca1baa | |||
6cb90e01de | |||
faa4574970 | |||
cc5755d5b1 | |||
85105fc070 | |||
ed40aee4c5 | |||
f8d8b16267 | |||
846e52f2ea | |||
69f541075c | |||
1debc31e3d | |||
1d798d4119 | |||
c1dde83abb | |||
280ac15da2 | |||
e751f7d815 | |||
e26e4740b3 | |||
835d76af45 | |||
a3e099bbc0 | |||
a61685696f | |||
02aa93c67c | |||
55b921818d | |||
bb681a8a11 | |||
74e0fbce42 | |||
f080c56771 | |||
d2f968b902 | |||
e81601acf3 | |||
7073dc0d5d | |||
d090be60e8 | |||
4bad96d9d6 |
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
import sqlite3
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
SqliteBoardImageRecordStorage,
|
SqliteBoardImageRecordStorage,
|
||||||
)
|
)
|
||||||
@ -28,6 +29,8 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
|
from ..services.batch_manager import BatchManager
|
||||||
|
from ..services.batch_manager_storage import SqliteBatchProcessStorage
|
||||||
from ..services.invocation_stats import InvocationStatsService
|
from ..services.invocation_stats import InvocationStatsService
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
@ -71,18 +74,18 @@ class ApiDependencies:
|
|||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
db_location = str(db_path)
|
db_location = str(db_path)
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||||
filename=db_location, table_name="graph_executions"
|
|
||||||
)
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||||
|
|
||||||
boards = BoardService(
|
boards = BoardService(
|
||||||
services=BoardServiceDependencies(
|
services=BoardServiceDependencies(
|
||||||
@ -116,15 +119,19 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config, logger),
|
model_manager=ModelManagerService(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
|
batch_manager=batch_manager,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
|
106
invokeai/app/api/routers/batches.py
Normal file
106
invokeai/app/api/routers/batches.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from fastapi import Body, HTTPException, Path, Response
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
|
||||||
|
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
|
||||||
|
|
||||||
|
# Importing * is bad karma but needed here for node detection
|
||||||
|
from ...invocations import * # noqa: F401 F403
|
||||||
|
from ...services.batch_manager import Batch, BatchProcessResponse
|
||||||
|
from ...services.graph import Graph
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"])
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.post(
|
||||||
|
"/",
|
||||||
|
operation_id="create_batch",
|
||||||
|
responses={
|
||||||
|
200: {"model": BatchProcessResponse},
|
||||||
|
400: {"description": "Invalid json"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def create_batch(
|
||||||
|
graph: Graph = Body(description="The graph to initialize the session with"),
|
||||||
|
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
||||||
|
) -> BatchProcessResponse:
|
||||||
|
"""Creates a batch process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.put(
|
||||||
|
"/b/{batch_process_id}/invoke",
|
||||||
|
operation_id="start_batch",
|
||||||
|
responses={
|
||||||
|
202: {"description": "Batch process started"},
|
||||||
|
404: {"description": "Batch session not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def start_batch(
|
||||||
|
batch_process_id: str = Path(description="ID of Batch to start"),
|
||||||
|
) -> Response:
|
||||||
|
"""Executes a batch process"""
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
|
||||||
|
return Response(status_code=202)
|
||||||
|
except BatchSessionNotFoundException:
|
||||||
|
raise HTTPException(status_code=404, detail="Batch session not found")
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.delete(
|
||||||
|
"/b/{batch_process_id}",
|
||||||
|
operation_id="cancel_batch",
|
||||||
|
responses={202: {"description": "The batch is canceled"}},
|
||||||
|
)
|
||||||
|
async def cancel_batch(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
||||||
|
) -> Response:
|
||||||
|
"""Cancels a batch process"""
|
||||||
|
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
||||||
|
return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/incomplete",
|
||||||
|
operation_id="list_incomplete_batches",
|
||||||
|
responses={200: {"model": list[BatchProcessResponse]}},
|
||||||
|
)
|
||||||
|
async def list_incomplete_batches() -> list[BatchProcessResponse]:
|
||||||
|
"""Lists incomplete batch processes"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/",
|
||||||
|
operation_id="list_batches",
|
||||||
|
responses={200: {"model": list[BatchProcessResponse]}},
|
||||||
|
)
|
||||||
|
async def list_batches() -> list[BatchProcessResponse]:
|
||||||
|
"""Lists all batch processes"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/b/{batch_process_id}",
|
||||||
|
operation_id="get_batch",
|
||||||
|
responses={200: {"model": BatchProcessResponse}},
|
||||||
|
)
|
||||||
|
async def get_batch(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||||
|
) -> BatchProcessResponse:
|
||||||
|
"""Gets a Batch Process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
|
||||||
|
|
||||||
|
|
||||||
|
@batches_router.get(
|
||||||
|
"/b/{batch_process_id}/sessions",
|
||||||
|
operation_id="get_batch_sessions",
|
||||||
|
responses={200: {"model": list[BatchSession]}},
|
||||||
|
)
|
||||||
|
async def get_batch_sessions(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||||
|
) -> list[BatchSession]:
|
||||||
|
"""Gets a list of batch sessions for a given batch process"""
|
||||||
|
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
|
@ -9,13 +9,7 @@ from pydantic.fields import Field
|
|||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from ...invocations import * # noqa: F401 F403
|
from ...invocations import * # noqa: F401 F403
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
from ...services.graph import (
|
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||||
Edge,
|
|
||||||
EdgeConnection,
|
|
||||||
Graph,
|
|
||||||
GraphExecutionState,
|
|
||||||
NodeAlreadyExecutedError,
|
|
||||||
)
|
|
||||||
from ...services.item_storage import PaginatedResults
|
from ...services.item_storage import PaginatedResults
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
@ -13,11 +13,15 @@ class SocketIO:
|
|||||||
|
|
||||||
def __init__(self, app: FastAPI):
|
def __init__(self, app: FastAPI):
|
||||||
self.__sio = SocketManager(app=app)
|
self.__sio = SocketManager(app=app)
|
||||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
|
||||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
|
||||||
|
|
||||||
|
self.__sio.on("subscribe_session", handler=self._handle_sub_session)
|
||||||
|
self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
|
||||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||||
|
|
||||||
|
self.__sio.on("subscribe_batch", handler=self._handle_sub_batch)
|
||||||
|
self.__sio.on("unsubscribe_batch", handler=self._handle_unsub_batch)
|
||||||
|
local_handler.register(event_name=EventServiceBase.batch_event, _func=self._handle_batch_event)
|
||||||
|
|
||||||
async def _handle_session_event(self, event: Event):
|
async def _handle_session_event(self, event: Event):
|
||||||
await self.__sio.emit(
|
await self.__sio.emit(
|
||||||
event=event[1]["event"],
|
event=event[1]["event"],
|
||||||
@ -25,12 +29,25 @@ class SocketIO:
|
|||||||
room=event[1]["data"]["graph_execution_state_id"],
|
room=event[1]["data"]["graph_execution_state_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
async def _handle_sub_session(self, sid, data, *args, **kwargs):
|
||||||
if "session" in data:
|
if "session" in data:
|
||||||
self.__sio.enter_room(sid, data["session"])
|
self.__sio.enter_room(sid, data["session"])
|
||||||
|
|
||||||
# @app.sio.on('unsubscribe')
|
async def _handle_unsub_session(self, sid, data, *args, **kwargs):
|
||||||
|
|
||||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
|
||||||
if "session" in data:
|
if "session" in data:
|
||||||
self.__sio.leave_room(sid, data["session"])
|
self.__sio.leave_room(sid, data["session"])
|
||||||
|
|
||||||
|
async def _handle_batch_event(self, event: Event):
|
||||||
|
await self.__sio.emit(
|
||||||
|
event=event[1]["event"],
|
||||||
|
data=event[1]["data"],
|
||||||
|
room=event[1]["data"]["batch_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_sub_batch(self, sid, data, *args, **kwargs):
|
||||||
|
if "batch_id" in data:
|
||||||
|
self.__sio.enter_room(sid, data["batch_id"])
|
||||||
|
|
||||||
|
async def _handle_unsub_batch(self, sid, data, *args, **kwargs):
|
||||||
|
if "batch_id" in data:
|
||||||
|
self.__sio.enter_room(sid, data["batch_id"])
|
||||||
|
@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
from .api.routers import sessions, batches, models, images, boards, board_images, app_info
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||||
|
|
||||||
@ -90,6 +90,8 @@ async def shutdown_event():
|
|||||||
|
|
||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(batches.batches_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(models.models_router, prefix="/api")
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
|
@ -5,6 +5,7 @@ import re
|
|||||||
import shlex
|
import shlex
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import sqlite3
|
||||||
from typing import Union, get_type_hints, Optional
|
from typing import Union, get_type_hints, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
@ -29,6 +30,8 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
from invokeai.app.services.batch_manager import BatchManager
|
||||||
|
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
@ -252,19 +255,18 @@ def invoke_cli():
|
|||||||
db_location = config.db_path
|
db_location = config.db_path
|
||||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
filename=db_location, table_name="graph_executions"
|
|
||||||
)
|
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||||
|
|
||||||
boards = BoardService(
|
boards = BoardService(
|
||||||
services=BoardServiceDependencies(
|
services=BoardServiceDependencies(
|
||||||
@ -298,15 +300,19 @@ def invoke_cli():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||||
images=images,
|
images=images,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
|
batch_manager=batch_manager,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
|
214
invokeai/app/services/batch_manager.py
Normal file
214
invokeai/app/services/batch_manager.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from itertools import product
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi_events.handlers.local import local_handler
|
||||||
|
from fastapi_events.typing import Event
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.services.batch_manager_storage import (
|
||||||
|
Batch,
|
||||||
|
BatchProcess,
|
||||||
|
BatchProcessStorageBase,
|
||||||
|
BatchSession,
|
||||||
|
BatchSessionChanges,
|
||||||
|
BatchSessionNotFoundException,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessResponse(BaseModel):
|
||||||
|
batch_id: str = Field(description="ID for the batch")
|
||||||
|
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchManagerBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
"""Starts the BatchManager service"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||||
|
"""Creates a batch process"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_batch_process(self, batch_id: str) -> None:
|
||||||
|
"""Runs a batch process"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||||
|
"""Cancels a batch process"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_batch(self, batch_id: str) -> BatchProcessResponse:
|
||||||
|
"""Gets a batch process"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
"""Gets all batch processes"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
"""Gets all incomplete batch processes"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||||
|
"""Gets the sessions associated with a batch"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BatchManager(BatchManagerBase):
|
||||||
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||||
|
|
||||||
|
__invoker: Invoker
|
||||||
|
__batch_process_storage: BatchProcessStorageBase
|
||||||
|
|
||||||
|
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.__batch_process_storage = batch_process_storage
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self.__invoker = invoker
|
||||||
|
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||||
|
|
||||||
|
async def on_event(self, event: Event):
|
||||||
|
event_name = event[1]["event"]
|
||||||
|
|
||||||
|
match event_name:
|
||||||
|
case "graph_execution_state_complete":
|
||||||
|
await self._process(event, False)
|
||||||
|
case "invocation_error":
|
||||||
|
await self._process(event, True)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def _process(self, event: Event, err: bool) -> None:
|
||||||
|
data = event[1]["data"]
|
||||||
|
try:
|
||||||
|
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
|
||||||
|
except BatchSessionNotFoundException:
|
||||||
|
return None
|
||||||
|
changes = BatchSessionChanges(state="error" if err else "completed")
|
||||||
|
batch_session = self.__batch_process_storage.update_session_state(
|
||||||
|
batch_session.batch_id,
|
||||||
|
batch_session.session_id,
|
||||||
|
changes,
|
||||||
|
)
|
||||||
|
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
||||||
|
if not batch_process.canceled:
|
||||||
|
self.run_batch_process(batch_process.batch_id)
|
||||||
|
|
||||||
|
def _create_graph_execution_state(
|
||||||
|
self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
|
||||||
|
) -> GraphExecutionState:
|
||||||
|
graph = batch_process.graph.copy(deep=True)
|
||||||
|
batch = batch_process.batch
|
||||||
|
for index, bdl in enumerate(batch.data):
|
||||||
|
for bd in bdl:
|
||||||
|
node = graph.get_node(bd.node_path)
|
||||||
|
if node is None:
|
||||||
|
continue
|
||||||
|
batch_index = batch_indices[index]
|
||||||
|
datum = bd.items[batch_index]
|
||||||
|
key = bd.field_name
|
||||||
|
node.__dict__[key] = datum
|
||||||
|
graph.update_node(bd.node_path, node)
|
||||||
|
|
||||||
|
return GraphExecutionState(graph=graph)
|
||||||
|
|
||||||
|
def run_batch_process(self, batch_id: str) -> None:
|
||||||
|
self.__batch_process_storage.start(batch_id)
|
||||||
|
batch_process = self.__batch_process_storage.get(batch_id)
|
||||||
|
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||||
|
if next_batch_index is None:
|
||||||
|
# finished with current run
|
||||||
|
if batch_process.current_run >= (batch_process.batch.runs - 1):
|
||||||
|
# finished with all runs
|
||||||
|
return
|
||||||
|
batch_process.current_batch_index = 0
|
||||||
|
batch_process.current_run += 1
|
||||||
|
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||||
|
if next_batch_index is None:
|
||||||
|
# shouldn't happen; satisfy types
|
||||||
|
return
|
||||||
|
# remember to increment the batch index
|
||||||
|
batch_process.current_batch_index += 1
|
||||||
|
self.__batch_process_storage.save(batch_process)
|
||||||
|
ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
|
||||||
|
next_session = self.__batch_process_storage.create_session(
|
||||||
|
BatchSession(
|
||||||
|
batch_id=batch_id,
|
||||||
|
session_id=str(uuid4()),
|
||||||
|
state="uninitialized",
|
||||||
|
batch_index=batch_process.current_batch_index,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ges.id = next_session.session_id
|
||||||
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||||||
|
self.__batch_process_storage.update_session_state(
|
||||||
|
batch_id=next_session.batch_id,
|
||||||
|
session_id=next_session.session_id,
|
||||||
|
changes=BatchSessionChanges(state="in_progress"),
|
||||||
|
)
|
||||||
|
self.__invoker.services.events.emit_batch_session_created(next_session.batch_id, next_session.session_id)
|
||||||
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
|
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||||
|
batch_process = BatchProcess(
|
||||||
|
batch=batch,
|
||||||
|
graph=graph,
|
||||||
|
)
|
||||||
|
batch_process = self.__batch_process_storage.save(batch_process)
|
||||||
|
return BatchProcessResponse(
|
||||||
|
batch_id=batch_process.batch_id,
|
||||||
|
session_ids=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||||
|
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
||||||
|
|
||||||
|
def get_batch(self, batch_id: str) -> BatchProcess:
|
||||||
|
return self.__batch_process_storage.get(batch_id)
|
||||||
|
|
||||||
|
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
bps = self.__batch_process_storage.get_all()
|
||||||
|
return self._get_batch_process_responses(bps)
|
||||||
|
|
||||||
|
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||||
|
bps = self.__batch_process_storage.get_incomplete()
|
||||||
|
return self._get_batch_process_responses(bps)
|
||||||
|
|
||||||
|
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||||
|
self.__batch_process_storage.cancel(batch_process_id)
|
||||||
|
|
||||||
|
def _get_batch_process_responses(self, batch_processes: list[BatchProcess]) -> list[BatchProcessResponse]:
|
||||||
|
sessions = list()
|
||||||
|
res: list[BatchProcessResponse] = list()
|
||||||
|
for bp in batch_processes:
|
||||||
|
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
||||||
|
res.append(
|
||||||
|
BatchProcessResponse(
|
||||||
|
batch_id=bp.batch_id,
|
||||||
|
session_ids=[session.session_id for session in sessions],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
|
||||||
|
batch_indices = list()
|
||||||
|
for batchdata in batch_process.batch.data:
|
||||||
|
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||||
|
try:
|
||||||
|
return list(product(*batch_indices))[batch_process.current_batch_index]
|
||||||
|
except IndexError:
|
||||||
|
return None
|
707
invokeai/app/services/batch_manager_storage.py
Normal file
707
invokeai/app/services/batch_manager_storage.py
Normal file
@ -0,0 +1,707 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Literal, Optional, Union, cast
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
|
||||||
|
|
||||||
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
|
|
||||||
|
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchData(BaseModel):
|
||||||
|
"""
|
||||||
|
A batch data collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||||
|
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||||
|
items: list[BatchDataType] = Field(
|
||||||
|
default_factory=list, description="The list of items to substitute into the node/field."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Batch(BaseModel):
|
||||||
|
"""
|
||||||
|
A batch, consisting of a list of a list of batch data collections.
|
||||||
|
|
||||||
|
First, each inner list[BatchData] is zipped into a single batch data collection.
|
||||||
|
|
||||||
|
Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
|
||||||
|
runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")
|
||||||
|
|
||||||
|
@validator("runs")
|
||||||
|
def validate_positive_runs(cls, r: int):
|
||||||
|
if r < 1:
|
||||||
|
raise ValueError("runs must be a positive integer")
|
||||||
|
return r
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_len(cls, v: list[list[BatchData]]):
|
||||||
|
for batch_data in v:
|
||||||
|
if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
|
||||||
|
raise ValueError("Zipped batch items must have all have same length")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_types(cls, v: list[list[BatchData]]):
|
||||||
|
for batch_data in v:
|
||||||
|
for datum in batch_data:
|
||||||
|
for item in datum.items:
|
||||||
|
if not all(isinstance(item, type(i)) for i in datum.items):
|
||||||
|
raise TypeError("All items in a batch must have have same type")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("data")
|
||||||
|
def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
|
||||||
|
paths: set[tuple[str, str]] = set()
|
||||||
|
count: int = 0
|
||||||
|
for batch_data in v:
|
||||||
|
for datum in batch_data:
|
||||||
|
paths.add((datum.node_path, datum.field_name))
|
||||||
|
count += 1
|
||||||
|
if len(paths) != count:
|
||||||
|
raise ValueError("Each batch data must have unique node_id and field_name")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def uuid_string():
|
||||||
|
res = uuid.uuid4()
|
||||||
|
return str(res)
|
||||||
|
|
||||||
|
|
||||||
|
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSession(BaseModel):
|
||||||
|
batch_id: str = Field(defaultdescription="The Batch to which this BatchSession is attached.")
|
||||||
|
session_id: str = Field(
|
||||||
|
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
||||||
|
)
|
||||||
|
batch_index: int = Field(description="The index of this batch session in its parent batch process")
|
||||||
|
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcess(BaseModel):
|
||||||
|
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||||
|
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||||
|
current_batch_index: int = Field(default=0, description="The last executed batch index")
|
||||||
|
current_run: int = Field(default=0, description="The current run of the batch")
|
||||||
|
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
|
||||||
|
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||||
|
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessNotFoundException(Exception):
|
||||||
|
"""Raised when an Batch Process record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchProcess record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessSaveException(Exception):
|
||||||
|
"""Raised when an Batch Process record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchProcess record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessDeleteException(Exception):
|
||||||
|
"""Raised when an Batch Process record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchProcess record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionNotFoundException(Exception):
|
||||||
|
"""Raised when an Batch Session record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchSession record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionSaveException(Exception):
|
||||||
|
"""Raised when an Batch Session record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchSession record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionDeleteException(Exception):
|
||||||
|
"""Raised when an Batch Session record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchSession record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the Batch Process record store."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, batch_id: str) -> None:
|
||||||
|
"""Deletes a BatchProcess record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
batch_process: BatchProcess,
|
||||||
|
) -> BatchProcess:
|
||||||
|
"""Saves a BatchProcess record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> BatchProcess:
|
||||||
|
"""Gets a BatchProcess record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
"""Gets a BatchProcess record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_incomplete(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
"""Gets a BatchProcess record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""'Starts' a BatchProcess record by marking its `canceled` attribute to False."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""'Cancels' a BatchProcess record by setting its `canceled` attribute to True."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
session: BatchSession,
|
||||||
|
) -> BatchSession:
|
||||||
|
"""Creates a BatchSession attached to a BatchProcess."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_sessions(
|
||||||
|
self,
|
||||||
|
sessions: list[BatchSession],
|
||||||
|
) -> list[BatchSession]:
|
||||||
|
"""Creates many BatchSessions attached to a BatchProcess."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||||
|
"""Gets a BatchSession by session_id"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||||
|
"""Gets all BatchSession's for a given BatchProcess id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||||
|
"""Gets all BatchSession's for a given list of session ids."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||||
|
"""Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_session_state(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
session_id: str,
|
||||||
|
changes: BatchSessionChanges,
|
||||||
|
) -> BatchSession:
|
||||||
|
"""Updates the state of a BatchSession record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.Lock
|
||||||
|
|
||||||
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._conn = conn
|
||||||
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# Enable foreign keys
|
||||||
|
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
self._create_tables()
|
||||||
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
"""Creates the `batch_process` table and `batch_session` junction table."""
|
||||||
|
|
||||||
|
# Create the `batch_process` table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS batch_process (
|
||||||
|
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
batch TEXT NOT NULL,
|
||||||
|
graph TEXT NOT NULL,
|
||||||
|
current_batch_index NUMBER NOT NULL,
|
||||||
|
current_run NUMBER NOT NULL,
|
||||||
|
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON batch_process FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE batch_process SET updated_at = current_timestamp
|
||||||
|
WHERE batch_id = old.batch_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the `batch_session` junction table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS batch_session (
|
||||||
|
batch_id TEXT NOT NULL,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
state TEXT NOT NULL,
|
||||||
|
batch_index NUMBER NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
-- enforce one-to-many relationship between batch_process and batch_session using PK
|
||||||
|
-- (we can extend this to many-to-many later)
|
||||||
|
PRIMARY KEY (batch_id,session_id),
|
||||||
|
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for batch id
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for batch id, sorted by created_at
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON batch_session FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE batch_id = old.batch_id AND session_id = old.session_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete(self, batch_id: str) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM batch_process
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessDeleteException from e
|
||||||
|
except Exception as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessDeleteException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
batch_process: BatchProcess,
|
||||||
|
) -> BatchProcess:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
|
||||||
|
VALUES (?, ?, ?, ?, ?);
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
batch_process.batch_id,
|
||||||
|
batch_process.batch.json(),
|
||||||
|
batch_process.graph.json(),
|
||||||
|
batch_process.current_batch_index,
|
||||||
|
batch_process.current_run,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(batch_process.batch_id)
|
||||||
|
|
||||||
|
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
|
||||||
|
"""Deserializes a batch session."""
|
||||||
|
|
||||||
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
|
batch_id = session_dict.get("batch_id", "unknown")
|
||||||
|
batch_raw = session_dict.get("batch", "unknown")
|
||||||
|
graph_raw = session_dict.get("graph", "unknown")
|
||||||
|
current_batch_index = session_dict.get("current_batch_index", 0)
|
||||||
|
current_run = session_dict.get("current_run", 0)
|
||||||
|
canceled = session_dict.get("canceled", 0)
|
||||||
|
return BatchProcess(
|
||||||
|
batch_id=batch_id,
|
||||||
|
batch=parse_raw_as(Batch, batch_raw),
|
||||||
|
graph=parse_raw_as(Graph, graph_raw),
|
||||||
|
current_batch_index=current_batch_index,
|
||||||
|
current_run=current_run,
|
||||||
|
canceled=canceled == 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> BatchProcess:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_process
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchProcessNotFoundException
|
||||||
|
return self._deserialize_batch_process(dict(result))
|
||||||
|
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_process
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
return list()
|
||||||
|
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||||
|
|
||||||
|
def get_incomplete(
|
||||||
|
self,
|
||||||
|
) -> list[BatchProcess]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT bp.*
|
||||||
|
FROM batch_process bp
|
||||||
|
WHERE bp.batch_id IN
|
||||||
|
(
|
||||||
|
SELECT batch_id
|
||||||
|
FROM batch_session bs
|
||||||
|
WHERE state IN ('uninitialized', 'in_progress')
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
return list()
|
||||||
|
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||||
|
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE batch_process
|
||||||
|
SET canceled = 0
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def cancel(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE batch_process
|
||||||
|
SET canceled = 1
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
session: BatchSession,
|
||||||
|
) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||||
|
VALUES (?, ?, ?, ?);
|
||||||
|
""",
|
||||||
|
(session.batch_id, session.session_id, session.state, session.batch_index),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get_session_by_session_id(session.session_id)
|
||||||
|
|
||||||
|
def create_sessions(
|
||||||
|
self,
|
||||||
|
sessions: list[BatchSession],
|
||||||
|
) -> list[BatchSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
session_data = [(session.batch_id, session.session_id, session.state) for session in sessions]
|
||||||
|
self._cursor.executemany(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||||
|
VALUES (?, ?, ?, ?);
|
||||||
|
""",
|
||||||
|
session_data,
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get_sessions_by_session_ids([session.session_id for session in sessions])
|
||||||
|
|
||||||
|
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_session
|
||||||
|
WHERE session_id= ?;
|
||||||
|
""",
|
||||||
|
(session_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
return self._deserialize_batch_session(dict(result))
|
||||||
|
|
||||||
|
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
||||||
|
"""Deserializes a batch session."""
|
||||||
|
|
||||||
|
return BatchSession.parse_obj(session_dict)
|
||||||
|
|
||||||
|
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_session
|
||||||
|
WHERE batch_id = ? AND state = 'uninitialized';
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
session = self._deserialize_batch_session(dict(result))
|
||||||
|
return session
|
||||||
|
|
||||||
|
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_session
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
placeholders = ",".join("?" * len(session_ids))
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
SELECT * FROM batch_session
|
||||||
|
WHERE session_id
|
||||||
|
IN ({placeholders})
|
||||||
|
""",
|
||||||
|
tuple(session_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
def update_session_state(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
session_id: str,
|
||||||
|
changes: BatchSessionChanges,
|
||||||
|
) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
# Change the state of a batch session
|
||||||
|
if changes.state is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE batch_session
|
||||||
|
SET state = ?
|
||||||
|
WHERE batch_id = ? AND session_id = ?;
|
||||||
|
""",
|
||||||
|
(changes.state, batch_id, session_id),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get_session_by_session_id(session_id)
|
@ -56,15 +56,13 @@ class BoardImageRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
@ -89,15 +89,13 @@ class BoardRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
@ -13,6 +13,7 @@ from invokeai.app.services.model_manager_service import (
|
|||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
session_event: str = "session_event"
|
session_event: str = "session_event"
|
||||||
|
batch_event: str = "batch_event"
|
||||||
|
|
||||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||||
|
|
||||||
@ -20,12 +21,21 @@ class EventServiceBase:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||||
|
"""Session events are emitted to a room with the session_id as the room name"""
|
||||||
payload["timestamp"] = get_timestamp()
|
payload["timestamp"] = get_timestamp()
|
||||||
self.dispatch(
|
self.dispatch(
|
||||||
event_name=EventServiceBase.session_event,
|
event_name=EventServiceBase.session_event,
|
||||||
payload=dict(event=event_name, data=payload),
|
payload=dict(event=event_name, data=payload),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __emit_batch_event(self, event_name: str, payload: dict) -> None:
|
||||||
|
"""Batch events are emitted to a room with the batch_id as the room name"""
|
||||||
|
payload["timestamp"] = get_timestamp()
|
||||||
|
self.dispatch(
|
||||||
|
event_name=EventServiceBase.batch_event,
|
||||||
|
payload=dict(event=event_name, data=payload),
|
||||||
|
)
|
||||||
|
|
||||||
# Define events here for every event in the system.
|
# Define events here for every event in the system.
|
||||||
# This will make them easier to integrate until we find a schema generator.
|
# This will make them easier to integrate until we find a schema generator.
|
||||||
def emit_generator_progress(
|
def emit_generator_progress(
|
||||||
@ -187,3 +197,14 @@ class EventServiceBase:
|
|||||||
error=error,
|
error=error,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def emit_batch_session_created(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Emitted when a batch session is created"""
|
||||||
|
self.__emit_batch_event(
|
||||||
|
event_name="batch_session_created",
|
||||||
|
payload=dict(batch_id=batch_id, graph_execution_state_id=graph_execution_state_id),
|
||||||
|
)
|
||||||
|
@ -152,15 +152,13 @@ class ImageRecordStorageBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||||
_filename: str
|
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filename = filename
|
self._conn = conn
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
self._conn.row_factory = sqlite3.Row
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
from invokeai.app.services.batch_manager import BatchManagerBase
|
||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
@ -22,6 +23,7 @@ class InvocationServices:
|
|||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
|
batch_manager: "BatchManagerBase"
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
@ -38,6 +40,7 @@ class InvocationServices:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
batch_manager: "BatchManagerBase",
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
@ -52,6 +55,7 @@ class InvocationServices:
|
|||||||
performance_statistics: "InvocationStatsServiceBase",
|
performance_statistics: "InvocationStatsServiceBase",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
):
|
):
|
||||||
|
self.batch_manager = batch_manager
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
|
|||||||
|
|
||||||
|
|
||||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||||
_filename: str
|
|
||||||
_table_name: str
|
_table_name: str
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: Lock
|
_lock: Lock
|
||||||
|
|
||||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._filename = filename
|
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
self._conn = sqlite3.connect(
|
self._conn = conn
|
||||||
self._filename, check_same_thread=False
|
|
||||||
) # TODO: figure out a better threading solution
|
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
parsed = parse_raw_as(item_type, item)
|
return parse_raw_as(item_type, item)
|
||||||
return parsed
|
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
|
344
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
344
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -25,6 +25,7 @@ from invokeai.app.services.graph import (
|
|||||||
LibraryGraph,
|
LibraryGraph,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -42,9 +43,8 @@ def simple_graph():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
filename=sqlite_memory, table_name="graph_executions"
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
)
|
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
events=TestEventService(),
|
events=TestEventService(),
|
||||||
@ -52,9 +52,10 @@ def mock_services() -> InvocationServices:
|
|||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
|
batch_manager=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
@ -12,12 +12,19 @@ from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
|
from invokeai.app.services.batch_manager_storage import BatchData, SqliteBatchProcessStorage
|
||||||
|
from invokeai.app.services.batch_manager import (
|
||||||
|
Batch,
|
||||||
|
BatchManager,
|
||||||
|
)
|
||||||
from invokeai.app.services.graph import (
|
from invokeai.app.services.graph import (
|
||||||
Graph,
|
Graph,
|
||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
|
GraphInvocation,
|
||||||
LibraryGraph,
|
LibraryGraph,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -29,25 +36,105 @@ def simple_graph():
|
|||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def simple_batch():
|
||||||
|
return Batch(
|
||||||
|
data=[
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Tomato sushi",
|
||||||
|
"Strawberry sushi",
|
||||||
|
"Broccoli sushi",
|
||||||
|
"Asparagus sushi",
|
||||||
|
"Tea sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="2",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Ume sushi",
|
||||||
|
"Ichigo sushi",
|
||||||
|
"Momo sushi",
|
||||||
|
"Mikan sushi",
|
||||||
|
"Cha sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def graph_with_subgraph():
|
||||||
|
sub_g = Graph()
|
||||||
|
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||||
|
sub_g.add_node(TextToImageTestInvocation(id="2"))
|
||||||
|
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||||
|
g = Graph()
|
||||||
|
g.add_node(GraphInvocation(id="1", graph=sub_g))
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def batch_with_subgraph():
|
||||||
|
return Batch(
|
||||||
|
data=[
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1.1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Tomato sushi",
|
||||||
|
"Strawberry sushi",
|
||||||
|
"Broccoli sushi",
|
||||||
|
"Asparagus sushi",
|
||||||
|
"Tea sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1.2",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Ume sushi",
|
||||||
|
"Ichigo sushi",
|
||||||
|
"Momo sushi",
|
||||||
|
"Mikan sushi",
|
||||||
|
"Cha sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||||
# the test invocations.
|
# the test invocations.
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
filename=sqlite_memory, table_name="graph_executions"
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
)
|
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
events=TestEventService(),
|
events=TestEventService(),
|
||||||
logger=None, # type: ignore
|
logger=None, # type: ignore
|
||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
|
batch_manager=BatchManager(batch_manager_storage),
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
@ -130,3 +217,120 @@ def test_handles_errors(mock_invoker: Invoker):
|
|||||||
assert g.is_complete()
|
assert g.is_complete()
|
||||||
|
|
||||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_create_batch_with_subgraph(mock_invoker: Invoker, graph_with_subgraph, batch_with_subgraph):
|
||||||
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
|
batch=batch_with_subgraph,
|
||||||
|
graph=graph_with_subgraph,
|
||||||
|
)
|
||||||
|
assert batch_process_res.batch_id
|
||||||
|
# TODO: without the mock events service emitting the `graph_execution_state` events,
|
||||||
|
# the batch sessions do not know when they have finished, so this logic will fail
|
||||||
|
|
||||||
|
# assert len(batch_process_res.session_ids) == 25
|
||||||
|
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||||
|
|
||||||
|
# def has_executed_all_batches(batch_id: str):
|
||||||
|
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
|
||||||
|
# print(batch_sessions)
|
||||||
|
# return all((s.state == "completed" for s in batch_sessions))
|
||||||
|
|
||||||
|
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||||
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
|
batch=simple_batch,
|
||||||
|
graph=simple_graph,
|
||||||
|
)
|
||||||
|
assert batch_process_res.batch_id
|
||||||
|
# TODO: without the mock events service emitting the `graph_execution_state` events,
|
||||||
|
# the batch sessions do not know when they have finished, so this logic will fail
|
||||||
|
|
||||||
|
# assert len(batch_process_res.session_ids) == 25
|
||||||
|
# mock_invoker.services.batch_manager.run_batch_process(batch_process_res.batch_id)
|
||||||
|
|
||||||
|
# def has_executed_all_batches(batch_id: str):
|
||||||
|
# batch_sessions = mock_invoker.services.batch_manager.get_sessions(batch_id)
|
||||||
|
# print(batch_sessions)
|
||||||
|
# return all((s.state == "completed" for s in batch_sessions))
|
||||||
|
|
||||||
|
# wait_until(lambda: has_executed_all_batches(batch_process_res.batch_id), timeout=10, interval=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cannot_create_bad_batches():
|
||||||
|
batch = None
|
||||||
|
try:
|
||||||
|
batch = Batch( # This batch has a duplicate node_path|fieldname combo
|
||||||
|
data=[
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Tomato sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Ume sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
try:
|
||||||
|
batch = Batch( # This batch has different item list lengths in the same group
|
||||||
|
data=[
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Tomato sushi",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
BatchData(
|
||||||
|
node_path="1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Tomato sushi",
|
||||||
|
"Courgette sushi",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=[
|
||||||
|
"Ume sushi",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
try:
|
||||||
|
batch = Batch( # This batch has a type mismatch in single items list
|
||||||
|
data=[
|
||||||
|
[
|
||||||
|
BatchData(
|
||||||
|
node_path="1",
|
||||||
|
field_name="prompt",
|
||||||
|
items=["Tomato sushi", 5],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
assert not batch
|
||||||
|
@ -51,6 +51,7 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
|
|||||||
@invocation("test_text_to_image")
|
@invocation("test_text_to_image")
|
||||||
class TextToImageTestInvocation(BaseInvocation):
|
class TextToImageTestInvocation(BaseInvocation):
|
||||||
prompt: str = Field(default="")
|
prompt: str = Field(default="")
|
||||||
|
prompt2: str = Field(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
@ -1,20 +1,27 @@
|
|||||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
id: str = Field(description="ID")
|
id: str = Field(description="ID")
|
||||||
name: str = Field(description="Name")
|
name: str = Field(description="Name")
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_create_and_get():
|
@pytest.fixture
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
def db() -> SqliteItemStorage[TestModel]:
|
||||||
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
|
return SqliteItemStorage[TestModel](db_conn, "test", "id")
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
assert db.get("1") == TestModel(id="1", name="Test")
|
assert db.get("1") == TestModel(id="1", name="Test")
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_list():
|
def test_sqlite_service_can_list(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -30,15 +37,13 @@ def test_sqlite_service_can_list():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_delete():
|
def test_sqlite_service_can_delete(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.delete("1")
|
db.delete("1")
|
||||||
assert db.get("1") is None
|
assert db.get("1") is None
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_calls_set_callback():
|
def test_sqlite_service_calls_set_callback(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
called = False
|
called = False
|
||||||
|
|
||||||
def on_changed(item: TestModel):
|
def on_changed(item: TestModel):
|
||||||
@ -50,8 +55,7 @@ def test_sqlite_service_calls_set_callback():
|
|||||||
assert called
|
assert called
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_calls_delete_callback():
|
def test_sqlite_service_calls_delete_callback(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
called = False
|
called = False
|
||||||
|
|
||||||
def on_deleted(item_id: str):
|
def on_deleted(item_id: str):
|
||||||
@ -64,8 +68,7 @@ def test_sqlite_service_calls_delete_callback():
|
|||||||
assert called
|
assert called
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_list_with_pagination():
|
def test_sqlite_service_can_list_with_pagination(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -77,8 +80,7 @@ def test_sqlite_service_can_list_with_pagination():
|
|||||||
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_list_with_pagination_and_offset():
|
def test_sqlite_service_can_list_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -90,8 +92,7 @@ def test_sqlite_service_can_list_with_pagination_and_offset():
|
|||||||
assert results.items == [TestModel(id="3", name="Test")]
|
assert results.items == [TestModel(id="3", name="Test")]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_search():
|
def test_sqlite_service_can_search(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -107,8 +108,7 @@ def test_sqlite_service_can_search():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_search_with_pagination():
|
def test_sqlite_service_can_search_with_pagination(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
@ -120,8 +120,7 @@ def test_sqlite_service_can_search_with_pagination():
|
|||||||
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_search_with_pagination_and_offset():
|
def test_sqlite_service_can_search_with_pagination_and_offset(db: SqliteItemStorage[TestModel]):
|
||||||
db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
|
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
db.set(TestModel(id="2", name="Test"))
|
db.set(TestModel(id="2", name="Test"))
|
||||||
db.set(TestModel(id="3", name="Test"))
|
db.set(TestModel(id="3", name="Test"))
|
||||||
|
Reference in New Issue
Block a user