mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run python black
This commit is contained in:
@ -25,6 +25,7 @@ InvocationsUnion = Union[invocations] # type: ignore
|
||||
|
||||
BatchDataType = Union[str, int, float, ImageField]
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
||||
node_id: str = Field(description="ID of the node to batch")
|
||||
@ -42,6 +43,7 @@ def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
||||
batches: List[Batch] = Field(
|
||||
@ -134,33 +136,24 @@ class BatchProcessStorageBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
"""Creates a Batch Session attached to a Batch Process."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_session(
|
||||
self,
|
||||
session_id: str
|
||||
) -> BatchSession:
|
||||
def get_session(self, session_id: str) -> BatchSession:
|
||||
"""Gets session by session_id"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_session(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> BatchSession:
|
||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_sessions(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> List[BatchSession]:
|
||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||
pass
|
||||
|
||||
@ -339,10 +332,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
batches = json.loads(batches_raw)
|
||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batches=batches,
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
canceled = canceled == 1
|
||||
batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
||||
)
|
||||
|
||||
def get(
|
||||
@ -357,7 +347,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,)
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
@ -370,7 +360,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
@ -393,8 +382,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
self._lock.release()
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -413,11 +402,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
self._lock.release()
|
||||
return self.get_session(session.session_id)
|
||||
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
session_id: str
|
||||
) -> BatchSession:
|
||||
def get_session(self, session_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -454,11 +439,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
state=state,
|
||||
)
|
||||
|
||||
|
||||
def get_created_session(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> BatchSession:
|
||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -481,11 +462,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
session = self._deserialize_batch_session(dict(result))
|
||||
return session
|
||||
|
||||
|
||||
def get_created_sessions(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> List[BatchSession]:
|
||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -497,7 +474,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -508,7 +484,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchSessionNotFoundException
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
|
||||
def update_session_state(
|
||||
self,
|
||||
@ -536,4 +511,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session(session_id)
|
||||
return self.get_session(session_id)
|
||||
|
Reference in New Issue
Block a user