mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run python black
This commit is contained in:
parent
846e52f2ea
commit
f8d8b16267
@ -19,6 +19,7 @@ from invokeai.app.services.batch_manager_storage import (
|
|||||||
BatchSessionChanges,
|
BatchSessionChanges,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BatchProcessResponse(BaseModel):
|
class BatchProcessResponse(BaseModel):
|
||||||
batch_id: str = Field(description="ID for the batch")
|
batch_id: str = Field(description="ID for the batch")
|
||||||
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
||||||
@ -75,9 +76,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
||||||
if not batch_session:
|
if not batch_session:
|
||||||
return
|
return
|
||||||
updateSession = BatchSessionChanges(
|
updateSession = BatchSessionChanges(state="error" if err else "completed")
|
||||||
state='error' if err else 'completed'
|
|
||||||
)
|
|
||||||
batch_session = self.__batch_process_storage.update_session_state(
|
batch_session = self.__batch_process_storage.update_session_state(
|
||||||
batch_session.batch_id,
|
batch_session.batch_id,
|
||||||
batch_session.session_id,
|
batch_session.session_id,
|
||||||
@ -105,18 +104,18 @@ class BatchManager(BatchManagerBase):
|
|||||||
return GraphExecutionState(graph=graph)
|
return GraphExecutionState(graph=graph)
|
||||||
|
|
||||||
def run_batch_process(self, batch_id: str):
|
def run_batch_process(self, batch_id: str):
|
||||||
try:
|
try:
|
||||||
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
||||||
except BatchSessionNotFoundException:
|
except BatchSessionNotFoundException:
|
||||||
return
|
return
|
||||||
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
||||||
self.__invoker.invoke(ges, invoke_all=True)
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||||
# TODO: Check that the node_ids in the batches are unique
|
# TODO: Check that the node_ids in the batches are unique
|
||||||
# TODO: Validate data types are correct for each batch data
|
# TODO: Validate data types are correct for each batch data
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
def create_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcessResponse:
|
||||||
batch_process = BatchProcess(
|
batch_process = BatchProcess(
|
||||||
batches=batches,
|
batches=batches,
|
||||||
@ -130,7 +129,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
batch_id=batch_process.batch_id,
|
batch_id=batch_process.batch_id,
|
||||||
session_ids=[session.session_id for session in sessions],
|
session_ids=[session.session_id for session in sessions],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
def _create_sessions(self, batch_process: BatchProcess) -> list[BatchSession]:
|
||||||
batch_indices = list()
|
batch_indices = list()
|
||||||
sessions = list()
|
sessions = list()
|
||||||
@ -140,11 +139,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
for bi in all_batch_indices:
|
for bi in all_batch_indices:
|
||||||
ges = self._create_batch_session(batch_process, bi)
|
ges = self._create_batch_session(batch_process, bi)
|
||||||
self.__invoker.services.graph_execution_manager.set(ges)
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||||||
batch_session = BatchSession(
|
batch_session = BatchSession(batch_id=batch_process.batch_id, session_id=ges.id, state="created")
|
||||||
batch_id=batch_process.batch_id,
|
|
||||||
session_id=ges.id,
|
|
||||||
state="created"
|
|
||||||
)
|
|
||||||
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
sessions.append(self.__batch_process_storage.create_session(batch_session))
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ InvocationsUnion = Union[invocations] # type: ignore
|
|||||||
|
|
||||||
BatchDataType = Union[str, int, float, ImageField]
|
BatchDataType = Union[str, int, float, ImageField]
|
||||||
|
|
||||||
|
|
||||||
class Batch(BaseModel):
|
class Batch(BaseModel):
|
||||||
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
||||||
node_id: str = Field(description="ID of the node to batch")
|
node_id: str = Field(description="ID of the node to batch")
|
||||||
@ -42,6 +43,7 @@ def uuid_string():
|
|||||||
res = uuid.uuid4()
|
res = uuid.uuid4()
|
||||||
return str(res)
|
return str(res)
|
||||||
|
|
||||||
|
|
||||||
class BatchProcess(BaseModel):
|
class BatchProcess(BaseModel):
|
||||||
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
||||||
batches: List[Batch] = Field(
|
batches: List[Batch] = Field(
|
||||||
@ -134,33 +136,24 @@ class BatchProcessStorageBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_session(
|
def create_session(
|
||||||
self,
|
self,
|
||||||
session: BatchSession,
|
session: BatchSession,
|
||||||
) -> BatchSession:
|
) -> BatchSession:
|
||||||
"""Creates a Batch Session attached to a Batch Process."""
|
"""Creates a Batch Session attached to a Batch Process."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_session(
|
def get_session(self, session_id: str) -> BatchSession:
|
||||||
self,
|
|
||||||
session_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
"""Gets session by session_id"""
|
"""Gets session by session_id"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_created_session(
|
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_created_sessions(
|
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> List[BatchSession]:
|
|
||||||
"""Gets all created Batch Sessions for a given Batch Process id."""
|
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -339,10 +332,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
batches = json.loads(batches_raw)
|
batches = json.loads(batches_raw)
|
||||||
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||||
return BatchProcess(
|
return BatchProcess(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id, batches=batches, graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
||||||
batches=batches,
|
|
||||||
graph=parse_raw_as(Graph, graph_raw),
|
|
||||||
canceled = canceled == 1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
@ -357,7 +347,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
FROM batch_process
|
FROM batch_process
|
||||||
WHERE batch_id = ?;
|
WHERE batch_id = ?;
|
||||||
""",
|
""",
|
||||||
(batch_id,)
|
(batch_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
@ -370,7 +360,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchProcessNotFoundException
|
raise BatchProcessNotFoundException
|
||||||
return self._deserialize_batch_process(dict(result))
|
return self._deserialize_batch_process(dict(result))
|
||||||
|
|
||||||
|
|
||||||
def cancel(
|
def cancel(
|
||||||
self,
|
self,
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
@ -393,8 +382,8 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def create_session(
|
def create_session(
|
||||||
self,
|
self,
|
||||||
session: BatchSession,
|
session: BatchSession,
|
||||||
) -> BatchSession:
|
) -> BatchSession:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
@ -413,11 +402,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get_session(session.session_id)
|
return self.get_session(session.session_id)
|
||||||
|
|
||||||
|
def get_session(self, session_id: str) -> BatchSession:
|
||||||
def get_session(
|
|
||||||
self,
|
|
||||||
session_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -454,11 +439,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
state=state,
|
state=state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_created_session(self, batch_id: str) -> BatchSession:
|
||||||
def get_created_session(
|
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> BatchSession:
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -481,11 +462,7 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
session = self._deserialize_batch_session(dict(result))
|
session = self._deserialize_batch_session(dict(result))
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
def get_created_sessions(self, batch_id: str) -> List[BatchSession]:
|
||||||
def get_created_sessions(
|
|
||||||
self,
|
|
||||||
batch_id: str
|
|
||||||
) -> List[BatchSession]:
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -497,7 +474,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
(batch_id,),
|
(batch_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -508,7 +484,6 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchSessionNotFoundException
|
raise BatchSessionNotFoundException
|
||||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
|
|
||||||
def update_session_state(
|
def update_session_state(
|
||||||
self,
|
self,
|
||||||
@ -536,4 +511,4 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
raise BatchSessionSaveException from e
|
raise BatchSessionSaveException from e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
return self.get_session(session_id)
|
return self.get_session(session_id)
|
||||||
|
@ -29,7 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._conn = sqlite3.connect(
|
self._conn = sqlite3.connect(
|
||||||
self._filename, check_same_thread=False
|
self._filename, check_same_thread=False
|
||||||
) # TODO: figure out a better threading solution
|
) # TODO: figure out a better threading solution
|
||||||
self._conn.execute('pragma journal_mode=wal')
|
self._conn.execute("pragma journal_mode=wal")
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
@ -39,7 +39,7 @@ def simple_batches():
|
|||||||
batches = [
|
batches = [
|
||||||
Batch(
|
Batch(
|
||||||
node_id=1,
|
node_id=1,
|
||||||
data= [
|
data=[
|
||||||
{
|
{
|
||||||
"prompt": "Tomato sushi",
|
"prompt": "Tomato sushi",
|
||||||
},
|
},
|
||||||
@ -55,11 +55,11 @@ def simple_batches():
|
|||||||
{
|
{
|
||||||
"prompt": "Tea sushi",
|
"prompt": "Tea sushi",
|
||||||
},
|
},
|
||||||
]
|
],
|
||||||
),
|
),
|
||||||
Batch(
|
Batch(
|
||||||
node_id="2",
|
node_id="2",
|
||||||
data= [
|
data=[
|
||||||
{
|
{
|
||||||
"prompt2": "Ume sushi",
|
"prompt2": "Ume sushi",
|
||||||
},
|
},
|
||||||
@ -75,8 +75,8 @@ def simple_batches():
|
|||||||
{
|
{
|
||||||
"prompt2": "Cha sushi",
|
"prompt2": "Cha sushi",
|
||||||
},
|
},
|
||||||
]
|
],
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
return batches
|
return batches
|
||||||
|
|
||||||
@ -185,6 +185,7 @@ def test_handles_errors(mock_invoker: Invoker):
|
|||||||
|
|
||||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||||
|
|
||||||
|
|
||||||
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
|
def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batches):
|
||||||
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
batch_process_res = mock_invoker.services.batch_manager.create_batch_process(
|
||||||
batches=simple_batches,
|
batches=simple_batches,
|
||||||
|
Loading…
Reference in New Issue
Block a user