mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run python black
This commit is contained in:
parent
796ee1246b
commit
f7277a8b21
@ -363,7 +363,10 @@ class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
|||||||
graph_raw = session_dict.get("graph", "unknown")
|
graph_raw = session_dict.get("graph", "unknown")
|
||||||
canceled = session_dict.get("canceled", 0)
|
canceled = session_dict.get("canceled", 0)
|
||||||
return BatchProcess(
|
return BatchProcess(
|
||||||
batch_id=batch_id, batch=parse_raw_as(Batch, batch_raw), graph=parse_raw_as(Graph, graph_raw), canceled=canceled == 1
|
batch_id=batch_id,
|
||||||
|
batch=parse_raw_as(Batch, batch_raw),
|
||||||
|
graph=parse_raw_as(Graph, graph_raw),
|
||||||
|
canceled=canceled == 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
|
@ -27,6 +27,7 @@ from invokeai.app.services.graph import (
|
|||||||
import pytest
|
import pytest
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def simple_graph():
|
def simple_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
@ -43,9 +44,7 @@ def simple_graph():
|
|||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||||
conn=db_conn, table_name="graph_executions"
|
|
||||||
)
|
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager=None, # type: ignore
|
model_manager=None, # type: ignore
|
||||||
events=TestEventService(),
|
events=TestEventService(),
|
||||||
|
@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
|||||||
import pytest
|
import pytest
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
id: str = Field(description="ID")
|
id: str = Field(description="ID")
|
||||||
name: str = Field(description="Name")
|
name: str = Field(description="Name")
|
||||||
@ -14,6 +15,7 @@ def db() -> SqliteItemStorage[TestModel]:
|
|||||||
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
|
||||||
return SqliteItemStorage[TestModel](db_conn, "test", "id")
|
return SqliteItemStorage[TestModel](db_conn, "test", "id")
|
||||||
|
|
||||||
|
|
||||||
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
|
||||||
db.set(TestModel(id="1", name="Test"))
|
db.set(TestModel(id="1", name="Test"))
|
||||||
assert db.get("1") == TestModel(id="1", name="Test")
|
assert db.get("1") == TestModel(id="1", name="Test")
|
||||||
|
Loading…
Reference in New Issue
Block a user