mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run python black
This commit is contained in:
parent
846e52f2ea
commit
f8d8b16267
@ -19,6 +19,7 @@ from invokeai.app.services.batch_manager_storage import (
|
||||
BatchSessionChanges,
|
||||
)
|
||||
|
||||
|
||||
class BatchProcessResponse(BaseModel):
|
||||
batch_id: str = Field(description="ID for the batch")
|
||||
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
||||
@ -75,9 +76,7 @@ class BatchManager(BatchManagerBase):
|
||||
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
||||
if not batch_session:
|
||||
return
|
||||
updateSession = BatchSessionChanges(
|
||||
state='error' if err else 'completed'
|
||||
)
|
||||
updateSession = BatchSessionChanges(state="error" if err else "completed")
|
||||
batch_session = self.__batch_process_storage.update_session_state(
|
||||
batch_session.batch_id,
|
||||
batch_session.session_id,
|
||||
@ -140,11 +139,7 @@ class BatchManager(BatchManagerBase):
|
||||
for bi in all_batch_indices:
|
||||
ges = self._create_batch_session(batch_process, bi)
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
batch_session = BatchSession(
|
||||
batch_id=batch_process.batch_id,
|
||||
session_id=ges.id,
|
||||
state="created"
|
||||
)
|
||||
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||
return sessions
|
||||
|
||||
|
@ -25,6 +25,7 @@ InvocationsUnion = Union[invocations] # type: ignore
|
||||
|
||||
BatchDataType = Union[str, int, float, ImageField]
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
||||
node_id: str = Field(description="ID of the node to batch")
|
||||
@ -42,6 +43,7 @@ def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
||||
batches: List[Batch] = Field(
|
||||
@ -141,26 +143,17 @@ class BatchProcessStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_session(
|
||||
self,
|
||||
session_id: str
|
||||
) -> BatchSession:
|
||||
def get_session(self, session_id: str) -> BatchSession:
|
||||
"""Gets session by session_id"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_session(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> BatchSession:
|
||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_created_sessions(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> List[BatchSession]:
|
||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||
pass
|
||||
|
||||
@ -339,10 +332,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
batches = json.loads(batches_raw)
|
||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batches=batches,
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
canceled = canceled == 1
|
||||
batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
||||
)
|
||||
|
||||
def get(
|
||||
@ -357,7 +347,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,)
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
@ -370,7 +360,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
@ -413,11 +402,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
self._lock.release()
|
||||
return self.get_session(session.session_id)
|
||||
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
session_id: str
|
||||
) -> BatchSession:
|
||||
def get_session(self, session_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -454,11 +439,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
state=state,
|
||||
)
|
||||
|
||||
|
||||
def get_created_session(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> BatchSession:
|
||||
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -481,11 +462,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
session = self._deserialize_batch_session(dict(result))
|
||||
return session
|
||||
|
||||
|
||||
def get_created_sessions(
|
||||
self,
|
||||
batch_id: str
|
||||
) -> List[BatchSession]:
|
||||
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -497,7 +474,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -509,7 +485,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
|
@ -29,7 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._conn.execute('pragma journal_mode=wal')
|
||||
self._conn.execute("pragma journal_mode=wal")
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
|
@ -55,7 +55,7 @@ def simple_batches():
|
||||
{
|
||||
"prompt": "Tea sushi",
|
||||
},
|
||||
]
|
||||
],
|
||||
),
|
||||
Batch(
|
||||
node_id="2",
|
||||
@ -75,8 +75,8 @@ def simple_batches():
|
||||
{
|
||||
"prompt2": "Cha sushi",
|
||||
},
|
||||
]
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
return batches
|
||||
|
||||
@ -185,6 +185,7 @@ def test_handles_errors(mock_invoker: Invoker):
|
||||
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||
|
||||
|
||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
|
||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||
batches=simple_batches,
|
||||
|
Loading…
Reference in New Issue
Block a user