More testing

This commit is contained in:
Brandon Rising 2023-08-10 14:09:00 -04:00
parent e26e4740b3
commit e751f7d815
8 changed files with 40 additions and 41 deletions

View File

@ -2,6 +2,7 @@
from typing import Optional from typing import Optional
from logging import Logger from logging import Logger
import threading
import os import os
from invokeai.app.services.board_image_record_storage import ( from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage, SqliteBoardImageRecordStorage,
@ -65,6 +66,8 @@ class ApiDependencies:
logger.info(f"Root directory = {str(config.root_path)}") logger.info(f"Root directory = {str(config.root_path)}")
logger.debug(f"Internet connectivity is {config.internet_available}") logger.debug(f"Internet connectivity is {config.internet_available}")
lock = threading.Lock()
events = FastAPIEventService(event_handler_id) events = FastAPIEventService(event_handler_id)
output_folder = config.output_path output_folder = config.output_path
@ -74,17 +77,17 @@ class ApiDependencies:
db_location.parent.mkdir(parents=True, exist_ok=True) db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions", lock=lock
) )
urls = LocalUrlService() urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location, lock=lock)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
board_record_storage = SqliteBoardRecordStorage(db_location) board_record_storage = SqliteBoardRecordStorage(db_location, lock=lock)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location) board_image_record_storage = SqliteBoardImageRecordStorage(db_location, lock=lock)
boards = BoardService( boards = BoardService(
services=BoardServiceDependencies( services=BoardServiceDependencies(
@ -118,7 +121,7 @@ class ApiDependencies:
) )
) )
batch_manager_storage = SqliteBatchProcessStorage(db_location) batch_manager_storage = SqliteBatchProcessStorage(db_location, lock=lock)
batch_manager = BatchManager(batch_manager_storage) batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices( services = InvocationServices(
@ -130,7 +133,7 @@ class ApiDependencies:
boards=boards, boards=boards,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs", lock=lock),
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(), processor=DefaultInvocationProcessor(),
configuration=config, configuration=config,

View File

@ -38,6 +38,7 @@ from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.batch_manager import BatchManager from invokeai.app.services.batch_manager import BatchManager
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -301,13 +302,16 @@ def invoke_cli():
) )
) )
batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices( services = InvocationServices(
model_manager=model_manager, model_manager=model_manager,
events=events, events=events,
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")), latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images, images=images,
boards=boards, boards=boards,
batch_manager=BatchManager(), batch_manager=batch_manager,
board_images=board_images, board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),

View File

@ -12,6 +12,7 @@ from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.batch_manager_storage import ( from invokeai.app.services.batch_manager_storage import (
BatchProcessStorageBase, BatchProcessStorageBase,
BatchSessionNotFoundException,
Batch, Batch,
BatchProcess, BatchProcess,
BatchSession, BatchSession,
@ -98,7 +99,10 @@ class BatchManager(BatchManagerBase):
return GraphExecutionState(graph=graph) return GraphExecutionState(graph=graph)
def run_batch_process(self, batch_id: str): def run_batch_process(self, batch_id: str):
created_session = self.__batch_process_storage.get_created_session(batch_id) try:
created_session = self.__batch_process_storage.get_created_session(batch_id)
except BatchSessionNotFoundException:
return
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id) ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
self.__invoker.invoke(ges, invoke_all=True) self.__invoker.invoke(ges, invoke_all=True)

View File

@ -15,9 +15,8 @@ import json
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
) )
from invokeai.app.services.graph import Graph, GraphExecutionState from invokeai.app.services.graph import Graph
from invokeai.app.models.image import ImageField from invokeai.app.models.image import ImageField
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from pydantic import BaseModel, Field, Extra, parse_raw_as from pydantic import BaseModel, Field, Extra, parse_raw_as
@ -169,14 +168,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = threading.Lock() self._lock = lock
try: try:
self._lock.acquire() self._lock.acquire()
@ -231,7 +230,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
CREATE TABLE IF NOT EXISTS batch_session ( CREATE TABLE IF NOT EXISTS batch_session (
batch_id TEXT NOT NULL, batch_id TEXT NOT NULL,
session_id TEXT NOT NULL, session_id TEXT NOT NULL,
state TEXT NOT NULL DEFAULT('created'), state TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger -- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -267,7 +266,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
ON batch_session FOR EACH ROW ON batch_session FOR EACH ROW
BEGIN BEGIN
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE batch_id = old.batch_id AND image_name = old.image_name; WHERE batch_id = old.batch_id AND session_id = old.session_id;
END; END;
""" """
) )
@ -298,12 +297,13 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
) -> BatchProcess: ) -> BatchProcess:
try: try:
self._lock.acquire() self._lock.acquire()
batches = [batch.json() for batch in batch_process.batches]
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph) INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
VALUES (?, ?, ?); VALUES (?, ?, ?);
""", """,
(batch_process.batch_id, json.dumps([batch.json() for batch in batch_process.batches]), batch_process.graph.json()), (batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
) )
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
@ -321,10 +321,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
batch_id = session_dict.get("batch_id", "unknown") batch_id = session_dict.get("batch_id", "unknown")
batches_raw = session_dict.get("batches", "unknown") batches_raw = session_dict.get("batches", "unknown")
graph_raw = session_dict.get("graph", "unknown") graph_raw = session_dict.get("graph", "unknown")
batches = json.loads(batches_raw)
batches = [parse_raw_as(Batch, batch) for batch in batches]
return BatchProcess( return BatchProcess(
batch_id=batch_id, batch_id=batch_id,
batches=[parse_raw_as(Batch, batch) for batch in json.loads(batches_raw)], batches=batches,
graph=parse_raw_as(Graph, graph_raw), graph=parse_raw_as(Graph, graph_raw),
) )
@ -398,7 +399,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
self._lock.release() self._lock.release()
if result is None: if result is None:
raise BatchSessionNotFoundException raise BatchSessionNotFoundException
return BatchSession(**dict(result)) return self._deserialize_batch_session(dict(result))
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession: def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
"""Deserializes a batch session.""" """Deserializes a batch session."""

View File

@ -62,14 +62,14 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = threading.Lock() self._lock = lock
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -95,14 +95,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = threading.Lock() self._lock = lock
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -155,14 +155,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock _lock: threading.Lock
def __init__(self, filename: str) -> None: def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._lock = threading.Lock() self._lock = lock
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -20,16 +20,16 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
_id_field: str _id_field: str
_lock: Lock _lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = "id"): def __init__(self, filename: str, table_name: str, id_field: str = "id", lock: Lock = Lock()):
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._table_name = table_name self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock() self._lock = lock
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
self._filename, check_same_thread=False self._filename, check_same_thread=False
) # TODO: figure out a better threading solution ) # TODO: figure out a better threading solution
self._conn.set_trace_callback(print) self._conn.execute('pragma journal_mode=wal')
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._create_table() self._create_table()
@ -55,21 +55,12 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def set(self, item: T): def set(self, item: T):
try: try:
self._lock.acquire() self._lock.acquire()
json = item.json()
print('-----------------locking db-----------------')
traceback.print_stack(limit=2)
self._cursor.execute( self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(json,), (item.json(),),
) )
self._conn.commit() self._conn.commit()
self._cursor.close()
self._cursor = self._conn.cursor()
print('-----------------unlocking db-----------------')
except Exception as e:
print("Exception!")
print(e)
finally: finally:
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
@ -77,12 +68,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get(self, id: str) -> Optional[T]: def get(self, id: str) -> Optional[T]:
try: try:
self._lock.acquire() self._lock.acquire()
print('-----------------locking db-----------------')
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone() result = self._cursor.fetchone()
self._cursor.close()
self._cursor = self._conn.cursor()
print('-----------------unlocking db-----------------')
finally: finally:
self._lock.release() self._lock.release()