Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
More testing
@@ -15,9 +15,8 @@ import json
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
 )
-from invokeai.app.services.graph import Graph, GraphExecutionState
+from invokeai.app.services.graph import Graph
 from invokeai.app.models.image import ImageField
 from invokeai.app.services.image_record_storage import OffsetPaginatedResults

 from pydantic import BaseModel, Field, Extra, parse_raw_as

@@ -169,14 +168,14 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
     _cursor: sqlite3.Cursor
     _lock: threading.Lock

-    def __init__(self, filename: str) -> None:
+    def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
         super().__init__()
         self._filename = filename
         self._conn = sqlite3.connect(filename, check_same_thread=False)
         # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
         self._conn.row_factory = sqlite3.Row
         self._cursor = self._conn.cursor()
-        self._lock = threading.Lock()
+        self._lock = lock

         try:
             self._lock.acquire()
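
The hunk above makes the lock injectable rather than always creating a fresh one per instance. A minimal sketch of the pattern this enables, with a hypothetical SharedStorage class standing in for the real service: callers can hand several storages the same lock so access to one SQLite file is serialized. Note that a default argument such as threading.Lock() is evaluated once at function-definition time, so instances constructed without an explicit lock all end up sharing that single default object.

import sqlite3
import threading

# Minimal sketch (not InvokeAI code) of the injectable-lock pattern above.
class SharedStorage:
    def __init__(self, filename: str, lock: threading.Lock = threading.Lock()) -> None:
        self._lock = lock
        self._conn = sqlite3.connect(filename, check_same_thread=False)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()

    def put(self, value: str) -> None:
        # Serialize writes to the connection behind the shared lock.
        try:
            self._lock.acquire()
            self._cursor.execute("CREATE TABLE IF NOT EXISTS demo (value TEXT);")
            self._cursor.execute("INSERT INTO demo (value) VALUES (?);", (value,))
            self._conn.commit()
        finally:
            self._lock.release()

shared_lock = threading.Lock()
# In real use both instances would point at the same on-disk file.
a = SharedStorage(":memory:", lock=shared_lock)
b = SharedStorage(":memory:", lock=shared_lock)  # both contend on the same lock
a.put("first")
b.put("second")
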
@@ -231,7 +230,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
             CREATE TABLE IF NOT EXISTS batch_session (
                 batch_id TEXT NOT NULL,
                 session_id TEXT NOT NULL,
-                state TEXT NOT NULL DEFAULT('created'),
+                state TEXT NOT NULL,
                 created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                 -- updated via trigger
                 updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@@ -267,7 +266,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
             ON batch_session FOR EACH ROW
             BEGIN
                 UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
-                    WHERE batch_id = old.batch_id AND image_name = old.image_name;
+                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
             END;
             """
         )
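
For reference, the corrected schema and updated_at trigger from the two hunks above read together as the standalone sketch below, runnable against an in-memory database. The trigger name is a guess, since the CREATE TRIGGER line sits outside the visible hunk, and later columns of the real table are omitted.

import sqlite3

# Standalone sketch of the corrected batch_session schema and its updated_at trigger.
# Trigger name is assumed; the diff only shows the trigger body.
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute(
    """--sql
    CREATE TABLE IF NOT EXISTS batch_session (
        batch_id TEXT NOT NULL,
        session_id TEXT NOT NULL,
        state TEXT NOT NULL,
        created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
        -- updated via trigger
        updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
    );
    """
)
cursor.execute(
    """--sql
    CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
    AFTER UPDATE
    ON batch_session FOR EACH ROW
    BEGIN
        UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            WHERE batch_id = old.batch_id AND session_id = old.session_id;
    END;
    """
)
conn.commit()
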
@@ -298,12 +297,13 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
     ) -> BatchProcess:
         try:
             self._lock.acquire()
+            batches = [batch.json() for batch in batch_process.batches]
             self._cursor.execute(
                 """--sql
                 INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
                 VALUES (?, ?, ?);
                 """,
-                (batch_process.batch_id, json.dumps([batch.json() for batch in batch_process.batches]), batch_process.graph.json()),
+                (batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
             )
             self._conn.commit()
         except sqlite3.Error as e:
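
This hunk only hoists the list comprehension into a local variable; the storage format itself is unchanged. As a hedged illustration of that format: each batch is serialized to its own JSON string with pydantic v1's .json(), and the resulting list of strings is JSON-encoded once more before landing in the batches TEXT column. The Item model below is a stand-in for the real Batch class, whose fields the diff does not show.

import json
from pydantic import BaseModel

# Hypothetical stand-in for the real Batch model.
class Item(BaseModel):
    name: str
    count: int

items = [Item(name="a", count=1), Item(name="b", count=2)]
batches = [item.json() for item in items]  # list of JSON strings, one per model
column_value = json.dumps(batches)         # single string stored in the TEXT column
print(column_value)
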
@@ -321,10 +321,11 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
         batch_id = session_dict.get("batch_id", "unknown")
         batches_raw = session_dict.get("batches", "unknown")
         graph_raw = session_dict.get("graph", "unknown")

         batches = json.loads(batches_raw)
+        batches = [parse_raw_as(Batch, batch) for batch in batches]
         return BatchProcess(
             batch_id=batch_id,
-            batches=[parse_raw_as(Batch, batch) for batch in json.loads(batches_raw)],
+            batches=batches,
             graph=parse_raw_as(Graph, graph_raw),
         )
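
Reading the column back reverses that double encoding, which is what the batches local now captures before BatchProcess is constructed: json.loads recovers the list of JSON strings, and parse_raw_as turns each element back into a model. A short sketch, again using a hypothetical Item model in place of Batch:

import json
from pydantic import BaseModel, parse_raw_as

# Hypothetical stand-in for the real Batch model.
class Item(BaseModel):
    name: str
    count: int

# What the batches TEXT column holds: a JSON array of per-model JSON strings.
column_value = json.dumps([Item(name="a", count=1).json(), Item(name="b", count=2).json()])

raw_items = json.loads(column_value)                    # back to a list of JSON strings
items = [parse_raw_as(Item, raw) for raw in raw_items]  # back to model instances
print(items)
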
@@ -398,7 +399,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
             self._lock.release()
         if result is None:
             raise BatchSessionNotFoundException
-        return BatchSession(**dict(result))
+        return self._deserialize_batch_session(dict(result))

     def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
         """Deserializes a batch session."""