Run python black

This commit is contained in:
Brandon Rising 2023-08-16 15:44:52 -04:00
parent 796ee1246b
commit f7277a8b21
3 changed files with 8 additions and 4 deletions

View File

@ -363,7 +363,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
graph_raw = session_dict.get("graph", "unknown") graph_raw = session_dict.get("graph", "unknown")
canceled = session_dict.get("canceled", 0) canceled = session_dict.get("canceled", 0)
return BatchProcess( return BatchProcess(
batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1 batch_id=batch_id,
batch=parse_raw_as(Batch, batch_raw),
graph=parse_raw_as(Graph, graph_raw),
canceled=canceled == 1,
) )
def get( def get(

View File

@ -27,6 +27,7 @@ from invokeai.app.services.graph import (
import pytest import pytest
import sqlite3 import sqlite3
@pytest.fixture @pytest.fixture
def simple_graph(): def simple_graph():
g = Graph() g = Graph()
@ -43,9 +44,7 @@ def simple_graph():
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False) db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
conn=db_conn, table_name="graph_executions"
)
return InvocationServices( return InvocationServices(
model_manager=None, # type: ignore model_manager=None, # type: ignore
events=TestEventService(), events=TestEventService(),

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
import pytest import pytest
import sqlite3 import sqlite3
class TestModel(BaseModel): class TestModel(BaseModel):
id: str = Field(description="ID") id: str = Field(description="ID")
name: str = Field(description="Name") name: str = Field(description="Name")
@ -14,6 +15,7 @@ def db() -> SqliteItemStorage[TestModel]:
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False) db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
return SqliteItemStorage[TestModel](db_conn, "test", "id") return SqliteItemStorage[TestModel](db_conn, "test", "id")
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]): def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
db.set(TestModel(id="1", name="Test")) db.set(TestModel(id="1", name="Test"))
assert db.get("1") == TestModel(id="1", name="Test") assert db.get("1") == TestModel(id="1", name="Test")