Run python black

This commit is contained in:
Brandon Rising 2023-08-14 11:01:31 -04:00
parent 846e52f2ea
commit f8d8b16267
4 changed files with 30 additions and 59 deletions

View File

@ -19,6 +19,7 @@ from invokeai.app.services.batch_manager_storage import (
BatchSessionChanges, BatchSessionChanges,
) )
class BatchProcessResponse(BaseModel): class BatchProcessResponse(BaseModel):
batch_id: str = Field(description="ID for the batch") batch_id: str = Field(description="ID for the batch")
session_ids: list[str] = Field(description="List of session IDs created for this batch") session_ids: list[str] = Field(description="List of session IDs created for this batch")
@ -75,9 +76,7 @@ class BatchManager(BatchManagerBase):
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"]) batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
if not batch_session: if not batch_session:
return return
updateSession = BatchSessionChanges( updateSession = BatchSessionChanges(state="error" if err else "completed")
state='error' if err else 'completed'
)
batch_session = self.__batch_process_storage.update_session_state( batch_session = self.__batch_process_storage.update_session_state(
batch_session.batch_id, batch_session.batch_id,
batch_session.session_id, batch_session.session_id,
@ -105,18 +104,18 @@ class BatchManager(BatchManagerBase):
return GraphExecutionState(graph=graph) return GraphExecutionState(graph=graph)
def run_batch_process(self, batch_id: str): def run_batch_process(self, batch_id: str):
try: try:
created_session = self.__batch_process_storage.get_created_session(batch_id) created_session = self.__batch_process_storage.get_created_session(batch_id)
except BatchSessionNotFoundException: except BatchSessionNotFoundException:
return return
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id) ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
self.__invoker.invoke(ges, invoke_all=True) self.__invoker.invoke(ges, invoke_all=True)
def _valid_batch_config(self, batch_process: BatchProcess) -> bool: def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
# TODO: Check that the node_ids in the batches are unique # TODO: Check that the node_ids in the batches are unique
# TODO: Validate data types are correct for each batch data # TODO: Validate data types are correct for each batch data
return True return True
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse: def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
batch_process = BatchProcess( batch_process = BatchProcess(
batches=batches, batches=batches,
@ -130,7 +129,7 @@ class BatchManager(BatchManagerBase):
batch_id=batch_process.batch_id, batch_id=batch_process.batch_id,
session_ids=[session.session_id for session in sessions], session_ids=[session.session_id for session in sessions],
) )
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]: def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
batch_indices = list() batch_indices = list()
sessions = list() sessions = list()
@ -140,11 +139,7 @@ class BatchManager(BatchManagerBase):
for bi in all_batch_indices: for bi in all_batch_indices:
ges = self._create_batch_session(batch_process, bi) ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges) self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession( batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
batch_id=batch_process.batch_id,
session_id=ges.id,
state="created"
)
sessions.append(self.__batch_process_storage.create_session(batch_session)) sessions.append(self.__batch_process_storage.create_session(batch_session))
return sessions return sessions

View File

@ -25,6 +25,7 @@ InvocationsUnion = Union[invocations] # type: ignore
BatchDataType = Union[str, int, float, ImageField] BatchDataType = Union[str, int, float, ImageField]
class Batch(BaseModel): class Batch(BaseModel):
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value") data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
node_id: str = Field(description="ID of the node to batch") node_id: str = Field(description="ID of the node to batch")
@ -42,6 +43,7 @@ def uuid_string():
res = uuid.uuid4() res = uuid.uuid4()
return str(res) return str(res)
class BatchProcess(BaseModel): class BatchProcess(BaseModel):
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch") batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
batches: List[Batch] = Field( batches: List[Batch] = Field(
@ -134,33 +136,24 @@ class BatchProcessStorageBase(ABC):
@abstractmethod @abstractmethod
def create_session( def create_session(
self, self,
session: BatchSession, session: BatchSession,
) -> BatchSession: ) -> BatchSession:
"""Creates a Batch Session attached to a Batch Process.""" """Creates a Batch Session attached to a Batch Process."""
pass pass
@abstractmethod @abstractmethod
def get_session( def get_session(self, session_id: str) -> BatchSession:
self,
session_id: str
) -> BatchSession:
"""Gets session by session_id""" """Gets session by session_id"""
pass pass
@abstractmethod @abstractmethod
def get_created_session( def get_created_session(self, batch_id: str) -> BatchSession:
self,
batch_id: str
) -> BatchSession:
"""Gets all created Batch Sessions for a given Batch Process id.""" """Gets all created Batch Sessions for a given Batch Process id."""
pass pass
@abstractmethod @abstractmethod
def get_created_sessions( def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
self,
batch_id: str
) -> List[BatchSession]:
"""Gets all created Batch Sessions for a given Batch Process id.""" """Gets all created Batch Sessions for a given Batch Process id."""
pass pass
@ -339,10 +332,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
batches = json.loads(batches_raw) batches = json.loads(batches_raw)
batches = [parse_raw_as(Batch, batch) for batch in batches] batches = [parse_raw_as(Batch, batch) for batch in batches]
return BatchProcess( return BatchProcess(
batch_id=batch_id, batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
batches=batches,
graph=parse_raw_as(Graph, graph_raw),
canceled = canceled == 1
) )
def get( def get(
@ -357,7 +347,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
FROM batch_process FROM batch_process
WHERE batch_id = ?; WHERE batch_id = ?;
""", """,
(batch_id,) (batch_id,),
) )
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
@ -370,7 +360,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchProcessNotFoundException raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result)) return self._deserialize_batch_process(dict(result))
def cancel( def cancel(
self, self,
batch_id: str, batch_id: str,
@ -393,8 +382,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
self._lock.release() self._lock.release()
def create_session( def create_session(
self, self,
session: BatchSession, session: BatchSession,
) -> BatchSession: ) -> BatchSession:
try: try:
self._lock.acquire() self._lock.acquire()
@ -413,11 +402,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
self._lock.release() self._lock.release()
return self.get_session(session.session_id) return self.get_session(session.session_id)
def get_session(self, session_id: str) -> BatchSession:
def get_session(
self,
session_id: str
) -> BatchSession:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -454,11 +439,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
state=state, state=state,
) )
def get_created_session(self, batch_id: str) -> BatchSession:
def get_created_session(
self,
batch_id: str
) -> BatchSession:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -481,11 +462,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
session = self._deserialize_batch_session(dict(result)) session = self._deserialize_batch_session(dict(result))
return session return session
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
def get_created_sessions(
self,
batch_id: str
) -> List[BatchSession]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -497,7 +474,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
(batch_id,), (batch_id,),
) )
result = cast(list[sqlite3.Row], self._cursor.fetchall()) result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@ -508,7 +484,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchSessionNotFoundException raise BatchSessionNotFoundException
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result)) sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions return sessions
def update_session_state( def update_session_state(
self, self,
@ -536,4 +511,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
raise BatchSessionSaveException from e raise BatchSessionSaveException from e
finally: finally:
self._lock.release() self._lock.release()
return self.get_session(session_id) return self.get_session(session_id)

View File

@ -29,7 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
self._filename, check_same_thread=False self._filename, check_same_thread=False
) # TODO: figure out a better threading solution ) # TODO: figure out a better threading solution
self._conn.execute('pragma journal_mode=wal') self._conn.execute("pragma journal_mode=wal")
self._cursor = self._conn.cursor() self._cursor = self._conn.cursor()
self._create_table() self._create_table()

View File

@ -39,7 +39,7 @@ def simple_batches():
batches = [ batches = [
Batch( Batch(
node_id=1, node_id=1,
data= [ data=[
{ {
"prompt": "Tomato sushi", "prompt": "Tomato sushi",
}, },
@ -55,11 +55,11 @@ def simple_batches():
{ {
"prompt": "Tea sushi", "prompt": "Tea sushi",
}, },
] ],
), ),
Batch( Batch(
node_id="2", node_id="2",
data= [ data=[
{ {
"prompt2": "Ume sushi", "prompt2": "Ume sushi",
}, },
@ -75,8 +75,8 @@ def simple_batches():
{ {
"prompt2": "Cha sushi", "prompt2": "Cha sushi",
}, },
] ],
) ),
] ]
return batches return batches
@ -185,6 +185,7 @@ def test_handles_errors(mock_invoker: Invoker):
assert all((i in g.errors for i in g.source_prepared_mapping["1"])) assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches): def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
batch_process_res = mock_invoker.services.batch_manager.create_batch_process( batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
batches=simple_batches, batches=simple_batches,