import sqlite3
import threading
import uuid
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Union, cast

from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator

from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.graph import Graph

BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]


class BatchData(BaseModel):
    """
    A batch data collection.
    """

    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
    field_name: str = Field(description="The field into which this batch data collection will be substituted.")
    items: list[BatchDataType] = Field(
        default_factory=list, description="The list of items to substitute into the node/field."
    )


class Batch(BaseModel):
    """
    A batch, consisting of a list of lists of batch data collections.

    First, each inner list[BatchData] is zipped into a single batch data collection.

    Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
    """

    data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
    runs: int = Field(default=1, description="The number of times to iterate through all possible batch indices")

    @validator("runs")
    def validate_positive_runs(cls, r: int):
        if r < 1:
            raise ValueError("runs must be a positive integer")
        return r

    @validator("data")
    def validate_len(cls, v: list[list[BatchData]]):
        for batch_data in v:
            if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
                raise ValueError("Zipped batch items must all have the same length")
        return v

    @validator("data")
    def validate_types(cls, v: list[list[BatchData]]):
        for batch_data in v:
            for datum in batch_data:
                for item in datum.items:
                    if not all(isinstance(item, type(i)) for i in datum.items):
                        raise TypeError("All items in a batch must have the same type")
        return v

    @validator("data")
    def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
        paths: set[tuple[str, str]] = set()
        count: int = 0
        for batch_data in v:
            for datum in batch_data:
                paths.add((datum.node_path, datum.field_name))
                count += 1
        if len(paths) != count:
            raise ValueError("Each batch data must have a unique node_path and field_name")
        return v
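# A minimal illustrative sketch (not part of the original module) of how a Batch
# might be constructed; the node paths and field names below are hypothetical.
# Each inner list is zipped together (so its items lists must share a length),
# and the outer list is expanded via Cartesian product:
#
#   example_batch = Batch(
#       runs=2,
#       data=[
#           [
#               BatchData(node_path="noise", field_name="seed", items=[1, 2, 3]),
#               BatchData(node_path="main_prompt", field_name="prompt", items=["a", "b", "c"]),
#           ],
#           [BatchData(node_path="noise", field_name="width", items=[512, 768])],
#       ],
#   )
#   # 3 zipped items x 2 widths = 6 batch indices, iterated twice because runs=2.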


def uuid_string():
    res = uuid.uuid4()
    return str(res)


BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]


class BatchSession(BaseModel):
    batch_id: str = Field(description="The Batch to which this BatchSession is attached.")
    session_id: str = Field(
        default_factory=uuid_string, description="The Session to which this BatchSession is attached."
    )
    batch_index: int = Field(description="The index of this batch session in its parent batch process")
    state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")


class BatchProcess(BaseModel):
    batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
    batch: Batch = Field(description="The Batch to apply to this session.")
    current_batch_index: int = Field(default=0, description="The last executed batch index")
    current_run: int = Field(default=0, description="The current run of the batch")
    canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
    graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")


class BatchSessionChanges(BaseModel, extra=Extra.forbid):
    state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")


class BatchProcessNotFoundException(Exception):
    """Raised when a Batch Process record is not found."""

    def __init__(self, message="BatchProcess record not found"):
        super().__init__(message)


class BatchProcessSaveException(Exception):
    """Raised when a Batch Process record cannot be saved."""

    def __init__(self, message="BatchProcess record not saved"):
        super().__init__(message)


class BatchProcessDeleteException(Exception):
    """Raised when a Batch Process record cannot be deleted."""

    def __init__(self, message="BatchProcess record not deleted"):
        super().__init__(message)


class BatchSessionNotFoundException(Exception):
    """Raised when a Batch Session record is not found."""

    def __init__(self, message="BatchSession record not found"):
        super().__init__(message)


class BatchSessionSaveException(Exception):
    """Raised when a Batch Session record cannot be saved."""

    def __init__(self, message="BatchSession record not saved"):
        super().__init__(message)


class BatchSessionDeleteException(Exception):
    """Raised when a Batch Session record cannot be deleted."""

    def __init__(self, message="BatchSession record not deleted"):
        super().__init__(message)


class BatchProcessStorageBase(ABC):
    """Low-level service responsible for interfacing with the Batch Process record store."""

    @abstractmethod
    def delete(self, batch_id: str) -> None:
        """Deletes a BatchProcess record."""
        pass

    @abstractmethod
    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        """Saves a BatchProcess record."""
        pass

    @abstractmethod
    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        """Gets a BatchProcess record."""
        pass

    @abstractmethod
    def get_all(
        self,
    ) -> list[BatchProcess]:
        """Gets all BatchProcess records."""
        pass

    @abstractmethod
    def get_incomplete(
        self,
    ) -> list[BatchProcess]:
        """Gets all BatchProcess records that still have incomplete sessions."""
        pass

    @abstractmethod
    def start(
        self,
        batch_id: str,
    ) -> None:
        """'Starts' a BatchProcess record by setting its `canceled` attribute to False."""
        pass

    @abstractmethod
    def cancel(
        self,
        batch_id: str,
    ) -> None:
        """'Cancels' a BatchProcess record by setting its `canceled` attribute to True."""
        pass

    @abstractmethod
    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        """Creates a BatchSession attached to a BatchProcess."""
        pass

    @abstractmethod
    def create_sessions(
        self,
        sessions: list[BatchSession],
    ) -> list[BatchSession]:
        """Creates many BatchSessions attached to a BatchProcess."""
        pass

    @abstractmethod
    def get_session_by_session_id(self, session_id: str) -> BatchSession:
        """Gets a BatchSession by session_id."""
        pass

    @abstractmethod
    def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
        """Gets all BatchSessions for a given BatchProcess id."""
        pass

    @abstractmethod
    def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
        """Gets all BatchSessions for a given list of session ids."""
        pass

    @abstractmethod
    def get_next_session(self, batch_id: str) -> BatchSession:
        """Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
        pass

    @abstractmethod
    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        """Updates the state of a BatchSession record."""
        pass
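# A hedged usage sketch of this interface (assuming a concrete implementation such
# as the SQLite-backed one below, plus a pre-built `my_batch: Batch` and
# `my_graph: Graph`):
#
#   storage: BatchProcessStorageBase = ...
#   process = storage.save(BatchProcess(batch=my_batch, graph=my_graph))
#   storage.create_sessions(
#       [BatchSession(batch_id=process.batch_id, batch_index=i) for i in range(3)]
#   )
#   storage.start(process.batch_id)
#   session = storage.get_next_session(process.batch_id)
#   storage.update_session_state(
#       process.batch_id, session.session_id, BatchSessionChanges(state="in_progress")
#   )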


class SqliteBatchProcessStorage(BatchProcessStorageBase):
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, conn: sqlite3.Connection) -> None:
        super().__init__()
        self._conn = conn
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        try:
            self._lock.acquire()
            # Enable foreign keys
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()
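    # A minimal wiring sketch (illustrative, not part of the original module). The
    # database path is hypothetical; `check_same_thread=False` is assumed because
    # the internal lock serializes access across threads:
    #
    #   conn = sqlite3.connect("invokeai_batches.db", check_same_thread=False)
    #   batch_storage = SqliteBatchProcessStorage(conn)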

    def _create_tables(self) -> None:
        """Creates the `batch_process` table and `batch_session` junction table."""

        # Create the `batch_process` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_process (
                batch_id TEXT NOT NULL PRIMARY KEY,
                batch TEXT NOT NULL,
                graph TEXT NOT NULL,
                current_batch_index NUMBER NOT NULL,
                current_run NUMBER NOT NULL,
                canceled BOOLEAN NOT NULL DEFAULT(0),
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME
            );
            """
        )

        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
            AFTER UPDATE
            ON batch_process FOR EACH ROW
            BEGIN
                UPDATE batch_process SET updated_at = current_timestamp
                    WHERE batch_id = old.batch_id;
            END;
            """
        )

        # Create the `batch_session` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS batch_session (
                batch_id TEXT NOT NULL,
                session_id TEXT NOT NULL,
                state TEXT NOT NULL,
                batch_index NUMBER NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between batch_process and batch_session using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (batch_id,session_id),
                FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
            );
            """
        )

        # Add index for batch id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
            """
        )

        # Add index for batch id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
            AFTER UPDATE
            ON batch_session FOR EACH ROW
            BEGIN
                UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE batch_id = old.batch_id AND session_id = old.session_id;
            END;
            """
        )
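    # Notes on the schema above: `batch` and `graph` are stored as JSON text (see
    # `save()` and `_deserialize_batch_process()` below); deleting a `batch_process`
    # row cascades to its `batch_session` rows via the foreign key; and the triggers
    # keep `updated_at` current without application-side bookkeeping.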

    def delete(self, batch_id: str) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                DELETE FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessDeleteException from e
        except Exception as e:
            self._conn.rollback()
            raise BatchProcessDeleteException from e
        finally:
            self._lock.release()

    def save(
        self,
        batch_process: BatchProcess,
    ) -> BatchProcess:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
                VALUES (?, ?, ?, ?, ?);
                """,
                (
                    batch_process.batch_id,
                    batch_process.batch.json(),
                    batch_process.graph.json(),
                    batch_process.current_batch_index,
                    batch_process.current_run,
                ),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessSaveException from e
        finally:
            self._lock.release()
        return self.get(batch_process.batch_id)

    def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
        """Deserializes a BatchProcess record."""

        # Retrieve all the values, setting "reasonable" defaults if they are not present.

        batch_id = session_dict.get("batch_id", "unknown")
        batch_raw = session_dict.get("batch", "unknown")
        graph_raw = session_dict.get("graph", "unknown")
        current_batch_index = session_dict.get("current_batch_index", 0)
        current_run = session_dict.get("current_run", 0)
        canceled = session_dict.get("canceled", 0)
        return BatchProcess(
            batch_id=batch_id,
            batch=parse_raw_as(Batch, batch_raw),
            graph=parse_raw_as(Graph, graph_raw),
            current_batch_index=current_batch_index,
            current_run=current_run,
            canceled=canceled == 1,
        )

    def get(
        self,
        batch_id: str,
    ) -> BatchProcess:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_process
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )

            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchProcessNotFoundException
        return self._deserialize_batch_process(dict(result))

    def get_all(
        self,
    ) -> list[BatchProcess]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_process
                """
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            return list()
        return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))

    def get_incomplete(
        self,
    ) -> list[BatchProcess]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT bp.*
                FROM batch_process bp
                WHERE bp.batch_id IN
                (
                    SELECT batch_id
                    FROM batch_session bs
                    WHERE state IN ('uninitialized', 'in_progress')
                );
                """
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchProcessNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            return list()
        return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))

    def start(
        self,
        batch_id: str,
    ) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                UPDATE batch_process
                SET canceled = 0
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()

    def cancel(
        self,
        batch_id: str,
    ) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                UPDATE batch_process
                SET canceled = 1
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()

    def create_session(
        self,
        session: BatchSession,
    ) -> BatchSession:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                VALUES (?, ?, ?, ?);
                """,
                (session.batch_id, session.session_id, session.state, session.batch_index),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session_by_session_id(session.session_id)

    def create_sessions(
        self,
        sessions: list[BatchSession],
    ) -> list[BatchSession]:
        try:
            self._lock.acquire()
            session_data = [
                (session.batch_id, session.session_id, session.state, session.batch_index) for session in sessions
            ]
            self._cursor.executemany(
                """--sql
                INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
                VALUES (?, ?, ?, ?);
                """,
                session_data,
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_sessions_by_session_ids([session.session_id for session in sessions])

    def get_session_by_session_id(self, session_id: str) -> BatchSession:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE session_id = ?;
                """,
                (session_id,),
            )

            result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        return self._deserialize_batch_session(dict(result))

    def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
        """Deserializes a batch session."""

        return BatchSession.parse_obj(session_dict)

    def get_next_session(self, batch_id: str) -> BatchSession:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ? AND state = 'uninitialized';
                """,
                (batch_id,),
            )

            result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        session = self._deserialize_batch_session(dict(result))
        return session

    def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT *
                FROM batch_session
                WHERE batch_id = ?;
                """,
                (batch_id,),
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
        return sessions

    def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
        try:
            self._lock.acquire()
            placeholders = ",".join("?" * len(session_ids))
            self._cursor.execute(
                f"""--sql
                SELECT * FROM batch_session
                WHERE session_id
                IN ({placeholders})
                """,
                tuple(session_ids),
            )

            result = cast(list[sqlite3.Row], self._cursor.fetchall())
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionNotFoundException from e
        finally:
            self._lock.release()
        if result is None:
            raise BatchSessionNotFoundException
        sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
        return sessions

    def update_session_state(
        self,
        batch_id: str,
        session_id: str,
        changes: BatchSessionChanges,
    ) -> BatchSession:
        try:
            self._lock.acquire()

            # Change the state of a batch session
            if changes.state is not None:
                self._cursor.execute(
                    """--sql
                    UPDATE batch_session
                    SET state = ?
                    WHERE batch_id = ? AND session_id = ?;
                    """,
                    (changes.state, batch_id, session_id),
                )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise BatchSessionSaveException from e
        finally:
            self._lock.release()
        return self.get_session_by_session_id(session_id)
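# A hedged sketch (illustrative only, not part of the original module) of resuming
# unfinished work after a restart, using the queries above; `batch_storage` is an
# instance of SqliteBatchProcessStorage as in the earlier wiring sketch:
#
#   for process in batch_storage.get_incomplete():
#       if not process.canceled:
#           session = batch_storage.get_next_session(process.batch_id)
#           batch_storage.update_session_state(
#               process.batch_id, session.session_id, BatchSessionChanges(state="in_progress")
#           )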