mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'ryan/remove-attention-map-saving' into ryan/regional-conditioning
This commit is contained in:
@ -1,174 +0,0 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
ErrorInvocation,
|
||||
PromptTestInvocation,
|
||||
TestEventService,
|
||||
TextToImageTestInvocation,
|
||||
create_edge,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_graph():
|
||||
g = Graph()
|
||||
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
g.add_node(TextToImageTestInvocation(id="2"))
|
||||
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_with_subgraph():
|
||||
sub_g = Graph()
|
||||
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
sub_g.add_node(TextToImageTestInvocation(id="2"))
|
||||
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||
g = Graph()
|
||||
g.add_node(GraphInvocation(id="1", graph=sub_g))
|
||||
return g
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
return InvocationServices(
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
board_records=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=ItemStorageMemory[GraphExecutionState](),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
latents=None, # type: ignore
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
model_records=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
model_install=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
queue=MemoryInvocationQueue(),
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||
return Invoker(services=mock_services)
|
||||
|
||||
|
||||
def test_can_create_graph_state(mock_invoker: Invoker):
|
||||
g = mock_invoker.create_execution_state()
|
||||
mock_invoker.stop()
|
||||
|
||||
assert g is not None
|
||||
assert isinstance(g, GraphExecutionState)
|
||||
|
||||
|
||||
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
mock_invoker.stop()
|
||||
|
||||
assert g is not None
|
||||
assert isinstance(g, GraphExecutionState)
|
||||
assert g.graph == simple_graph
|
||||
|
||||
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
invocation_id = mock_invoker.invoke(
|
||||
session_queue_batch_id="1",
|
||||
session_queue_item_id=1,
|
||||
session_queue_id=DEFAULT_QUEUE_ID,
|
||||
graph_execution_state=g,
|
||||
)
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_any(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return len(g.executed) > 0
|
||||
|
||||
wait_until(lambda: has_executed_any(g), timeout=5, interval=1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert len(g.executed) > 0
|
||||
|
||||
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
invocation_id = mock_invoker.invoke(
|
||||
session_queue_batch_id="1",
|
||||
session_queue_item_id=1,
|
||||
session_queue_id=DEFAULT_QUEUE_ID,
|
||||
graph_execution_state=g,
|
||||
invoke_all=True,
|
||||
)
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.is_complete()
|
||||
|
||||
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_handles_errors(mock_invoker: Invoker):
|
||||
g = mock_invoker.create_execution_state()
|
||||
g.graph.add_node(ErrorInvocation(id="1"))
|
||||
|
||||
mock_invoker.invoke(
|
||||
session_queue_batch_id="1",
|
||||
session_queue_item_id=1,
|
||||
session_queue_id=DEFAULT_QUEUE_ID,
|
||||
graph_execution_state=g,
|
||||
invoke_all=True,
|
||||
)
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.has_error()
|
||||
assert g.is_complete()
|
||||
|
||||
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
119
tests/app/routers/test_images.py
Normal file
119
tests/app/routers/test_images.py
Normal file
@ -0,0 +1,119 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api_app import app
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecord
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def client(invokeai_root_dir: Path) -> TestClient:
|
||||
os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix()
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class MockApiDependencies(ApiDependencies):
|
||||
invoker: Invoker
|
||||
|
||||
def __init__(self, invoker) -> None:
|
||||
self.invoker = invoker
|
||||
|
||||
|
||||
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||
|
||||
response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
|
||||
json_response = response.json()
|
||||
assert response.status_code == 202
|
||||
assert json_response["bulk_download_item_name"] == "test.zip"
|
||||
|
||||
|
||||
def test_download_images_from_board_id_empty_image_name_list(
|
||||
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
|
||||
) -> None:
|
||||
expected_board_name = "test"
|
||||
|
||||
def mock_get(*args, **kwargs):
|
||||
return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None")
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get)
|
||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||
|
||||
response = client.post("/api/v1/images/download", json={"board_id": "test"})
|
||||
json_response = response.json()
|
||||
assert response.status_code == 202
|
||||
assert json_response["bulk_download_item_name"] == "test.zip"
|
||||
|
||||
|
||||
def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None:
|
||||
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
||||
monkeypatch.setattr(
|
||||
"invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id",
|
||||
lambda arg: "test",
|
||||
)
|
||||
|
||||
def mock_add_task(*args, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
||||
|
||||
|
||||
def test_download_images_with_empty_image_list_and_no_board_id(
|
||||
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
|
||||
) -> None:
|
||||
prepare_download_images_test(monkeypatch, mock_invoker)
|
||||
|
||||
response = client.post("/api/v1/images/download", json={"image_names": []})
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
||||
mock_file: Path = tmp_path / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
|
||||
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
||||
|
||||
def mock_add_task(*args, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
||||
|
||||
response = client.get("/api/v1/images/download/test.zip")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"contents"
|
||||
|
||||
|
||||
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
|
||||
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
||||
|
||||
def mock_add_task(*args, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)
|
||||
|
||||
response = client.get("/api/v1/images/download/test.zip")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_get_bulk_download_image_image_deleted_after_response(
|
||||
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient
|
||||
) -> None:
|
||||
mock_file: Path = tmp_path / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file))
|
||||
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))
|
||||
|
||||
client.get("/api/v1/images/download/test.zip")
|
||||
|
||||
assert not (tmp_path / "test.zip").exists()
|
379
tests/app/services/bulk_download/test_bulk_download.py
Normal file
379
tests/app/services/bulk_download/test_bulk_download.py
Normal file
@ -0,0 +1,379 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
|
||||
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecordNotFoundException,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image_dto() -> ImageDTO:
|
||||
"""Create a mock ImageDTO."""
|
||||
return ImageDTO(
|
||||
image_name="mock_image.png",
|
||||
board_id="12345",
|
||||
image_url="None",
|
||||
width=100,
|
||||
height=100,
|
||||
thumbnail_url="None",
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
created_at="None",
|
||||
updated_at="None",
|
||||
starred=False,
|
||||
has_workflow=False,
|
||||
is_intermediate=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
|
||||
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
|
||||
|
||||
class MockTemporaryDirectory(TemporaryDirectory):
|
||||
def __init__(self):
|
||||
super().__init__(dir=tmp_path)
|
||||
self.name = tmp_path
|
||||
|
||||
def mock_TemporaryDirectory(*args, **kwargs):
|
||||
return MockTemporaryDirectory()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"invokeai.app.services.bulk_download.bulk_download_default.TemporaryDirectory", mock_TemporaryDirectory
|
||||
)
|
||||
|
||||
|
||||
def test_get_path_when_file_exists(tmp_path: Path) -> None:
|
||||
"""Test get_path when the file exists."""
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
# Create a directory at tmp_path/bulk_downloads
|
||||
test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads"
|
||||
test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a file at tmp_path/bulk_downloads/test.zip
|
||||
test_file_path: Path = test_bulk_downloads_dir / "test.zip"
|
||||
test_file_path.touch()
|
||||
|
||||
assert bulk_download_service.get_path("test.zip") == str(test_file_path)
|
||||
|
||||
|
||||
def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None:
|
||||
"""Test get_path when the file does not exist."""
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
with pytest.raises(BulkDownloadTargetException):
|
||||
bulk_download_service.get_path("test")
|
||||
|
||||
|
||||
def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None:
|
||||
"""Test that the bulk_downloads directory is created at start."""
|
||||
|
||||
BulkDownloadService()
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
|
||||
|
||||
def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Test that the handler creates the zip file correctly when given a list of image names."""
|
||||
|
||||
expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
)
|
||||
|
||||
|
||||
def test_generate_id(monkeypatch: Any):
|
||||
"""Test that the generate_id method generates a unique id."""
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||
|
||||
assert bulk_download_service.generate_item_id(None) == "test"
|
||||
|
||||
|
||||
def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
||||
"""Test that the generate_id method generates a unique id with a board id."""
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
|
||||
def mock_board_get(*args, **kwargs):
|
||||
return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
|
||||
|
||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||
|
||||
assert bulk_download_service.generate_item_id("12345") == "test_board_name_test"
|
||||
|
||||
|
||||
def test_generate_id_with_default_board_id(monkeypatch: Any):
|
||||
"""Test that the generate_id method generates a unique id with a board id."""
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test")
|
||||
|
||||
assert bulk_download_service.generate_item_id("none") == "Uncategorized_test"
|
||||
|
||||
|
||||
def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Test that the handler creates the zip file correctly when given a board id."""
|
||||
|
||||
expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
|
||||
def mock_board_get(*args, **kwargs):
|
||||
return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None")
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
|
||||
|
||||
def mock_get_many(*args, **kwargs):
|
||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([], "test", None)
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
)
|
||||
|
||||
|
||||
def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Test that the handler creates the zip file correctly when given a board id."""
|
||||
|
||||
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
|
||||
def mock_get_many(*args, **kwargs):
|
||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([], "none", None)
|
||||
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
)
|
||||
|
||||
|
||||
def test_handler_bulk_download_item_id_given(
|
||||
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
|
||||
):
|
||||
"""Test that the handler creates the zip file correctly when given a pregenerated bulk download item id."""
|
||||
|
||||
_, expected_image_path, mock_image_contents = prepare_handler_test(
|
||||
tmp_path, monkeypatch, mock_image_dto, mock_invoker
|
||||
)
|
||||
|
||||
def mock_get_many(*args, **kwargs):
|
||||
return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto])
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many)
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id")
|
||||
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip"
|
||||
|
||||
assert_handler_success(
|
||||
expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events
|
||||
)
|
||||
|
||||
|
||||
def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Prepare the test for the handler tests."""
|
||||
|
||||
def mock_uuid_string():
|
||||
return "test"
|
||||
|
||||
# You have to patch the function within the module it's being imported into. This is strange, but it works.
|
||||
# See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/
|
||||
monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", mock_uuid_string)
|
||||
|
||||
expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
expected_image_path: Path = (
|
||||
tmp_path / "bulk_downloads" / mock_image_dto.image_category.value / mock_image_dto.image_name
|
||||
)
|
||||
|
||||
# Mock the get_dto method so that when the image dto needs to be retrieved it is returned
|
||||
def mock_get_dto(*args, **kwargs):
|
||||
return mock_image_dto
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto)
|
||||
|
||||
# This is used when preparing all images for a given board
|
||||
def mock_get_all_board_image_names_for_board(*args, **kwargs):
|
||||
return [mock_image_dto.image_name]
|
||||
|
||||
monkeypatch.setattr(
|
||||
mock_invoker.services.board_image_records,
|
||||
"get_all_board_image_names_for_board",
|
||||
mock_get_all_board_image_names_for_board,
|
||||
)
|
||||
|
||||
# Create a mock image file so that the contents of the zip file are not empty
|
||||
mock_image_path: Path = tmp_path / mock_image_dto.image_name
|
||||
mock_image_contents: str = "Totally an image"
|
||||
mock_image_path.write_text(mock_image_contents)
|
||||
|
||||
def mock_get_path(*args, **kwargs):
|
||||
return str(mock_image_path)
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_path", mock_get_path)
|
||||
|
||||
return expected_zip_path, expected_image_path, mock_image_contents
|
||||
|
||||
|
||||
def assert_handler_success(
|
||||
expected_zip_path: Path,
|
||||
expected_image_path: Path,
|
||||
mock_image_contents: str,
|
||||
tmp_path: Path,
|
||||
event_bus: TestEventService,
|
||||
):
|
||||
"""Assert that the handler was successful."""
|
||||
# Check that the zip file was created
|
||||
assert expected_zip_path.exists()
|
||||
assert expected_zip_path.is_file()
|
||||
assert expected_zip_path.stat().st_size > 0
|
||||
|
||||
# Check that the zip contents are expected
|
||||
with ZipFile(expected_zip_path, "r") as zip_file:
|
||||
zip_file.extractall(tmp_path / "bulk_downloads")
|
||||
assert expected_image_path.exists()
|
||||
assert expected_image_path.is_file()
|
||||
assert expected_image_path.stat().st_size > 0
|
||||
assert expected_image_path.read_text() == mock_image_contents
|
||||
|
||||
# Check that the correct events were emitted
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_completed"
|
||||
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
|
||||
|
||||
|
||||
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Test that the handler emits an error event when the image is not found."""
|
||||
exception: Exception = ImageRecordNotFoundException("Image not found")
|
||||
|
||||
def mock_get_dto(*args, **kwargs):
|
||||
raise exception
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto)
|
||||
|
||||
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
|
||||
|
||||
|
||||
def test_handler_on_board_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
"""Test that the handler emits an error event when the image is not found."""
|
||||
|
||||
exception: Exception = BoardRecordNotFoundException("Image not found")
|
||||
|
||||
def mock_get_board_name(*args, **kwargs):
|
||||
raise exception
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name)
|
||||
|
||||
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
|
||||
|
||||
|
||||
def test_handler_on_generic_exception(
|
||||
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker
|
||||
):
|
||||
"""Test that the handler emits an error event when the image is not found."""
|
||||
|
||||
exception: Exception = Exception("Generic exception")
|
||||
|
||||
def mock_get_board_name(*args, **kwargs):
|
||||
raise exception
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name)
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
|
||||
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == exception.__str__()
|
||||
|
||||
|
||||
def execute_handler_test_on_error(
|
||||
tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception
|
||||
):
|
||||
bulk_download_service = BulkDownloadService()
|
||||
bulk_download_service.start(mock_invoker)
|
||||
bulk_download_service.handler([mock_image_dto.image_name], None, None)
|
||||
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||
|
||||
|
||||
def test_delete(tmp_path: Path):
|
||||
"""Test that the delete method removes the bulk download file."""
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
|
||||
bulk_download_service.delete("test.zip")
|
||||
|
||||
assert (tmp_path / "bulk_downloads").exists()
|
||||
assert len(os.listdir(tmp_path / "bulk_downloads")) == 0
|
||||
|
||||
|
||||
def test_stop(tmp_path: Path):
|
||||
"""Test that the stop method removes the bulk download file and not any directories."""
|
||||
|
||||
bulk_download_service = BulkDownloadService()
|
||||
|
||||
mock_file: Path = tmp_path / "bulk_downloads" / "test.zip"
|
||||
mock_file.write_text("contents")
|
||||
|
||||
mock_dir: Path = tmp_path / "bulk_downloads" / "test"
|
||||
mock_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
bulk_download_service.stop()
|
||||
|
||||
assert not (tmp_path / "bulk_downloads").exists()
|
@ -1,20 +1,19 @@
|
||||
"""Test the queued download facility"""
|
||||
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
TestAdapter.__test__ = False
|
||||
TestAdapter.__test__ = False # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -52,28 +51,6 @@ def session() -> Session:
|
||||
return sess
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
# A dummy event service for testing event issuing
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
|
||||
|
||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
events = set()
|
||||
|
||||
@ -125,7 +102,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
|
||||
|
||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
event_bus = DummyEventService()
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue.start()
|
||||
@ -190,8 +167,9 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=15, method="thread")
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
event_bus = DummyEventService()
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue.start()
|
||||
@ -205,6 +183,9 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Join took too long to return")
|
||||
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
|
@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
|
||||
)
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
OS = platform.uname().system
|
||||
|
||||
@ -31,6 +31,7 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
|
||||
assert len(matches) == 0
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
assert key is not None
|
||||
assert key != "<NOKEY>"
|
||||
assert len(key) == 32
|
||||
|
||||
|
||||
@ -58,12 +59,13 @@ def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase,
|
||||
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(
|
||||
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
|
||||
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash", "key": "xyzzy"}
|
||||
)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.name == "banana_sushi"
|
||||
assert model_record.source == "fake/repo_id"
|
||||
assert model_record.current_hash == "New Hash"
|
||||
assert model_record.key == "xyzzy"
|
||||
|
||||
|
||||
def test_install(
|
||||
@ -129,6 +131,7 @@ def test_background_install(
|
||||
model_record = mm2_installer.record_store.get_model(key)
|
||||
assert model_record is not None
|
||||
assert model_record.path == destination
|
||||
assert model_record.key != "<NOKEY>"
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
|
||||
# see if metadata was properly passed through
|
||||
@ -196,6 +199,7 @@ def test_delete_register(
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
@ -221,6 +225,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
@ -256,4 +261,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
||||
assert job.error_type == "HTTPError"
|
||||
assert job.error
|
||||
assert "NOT FOUND" in job.error
|
||||
assert "Traceback" in job.error
|
||||
assert job.error_traceback.startswith("Traceback")
|
||||
|
@ -8,6 +8,7 @@ from typing import Any
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordOrderBy,
|
||||
@ -25,7 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import BaseMetadata
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
@ -36,7 +37,7 @@ def store(
|
||||
config = InvokeAIAppConfig(root=datadir)
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = create_mock_sqlite_database(config, logger)
|
||||
return ModelRecordServiceSQL(db)
|
||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
|
||||
def example_config() -> TextualInversionConfig:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||
from invokeai.backend.util.test_utils import install_and_load_model
|
||||
|
||||
|
30
tests/backend/model_manager/model_loading/test_model_load.py
Normal file
30
tests/backend/model_manager/model_loading/test_model_load.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""
|
||||
Test model loading
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
|
||||
store = mm2_model_manager.store
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = mm2_model_manager.install.register_path(embedding_file)
|
||||
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
|
||||
assert loaded_model is not None
|
||||
assert loaded_model.config.key == key
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
|
||||
assert loaded_model.config.key == loaded_model_2.config.key
|
||||
|
||||
loaded_model_3 = mm2_model_manager.load_model_by_attr(
|
||||
model_name=loaded_model.config.name,
|
||||
model_type=loaded_model.config.type,
|
||||
base_model=loaded_model.config.base,
|
||||
)
|
||||
assert loaded_model.config.key == loaded_model_3.config.key
|
@ -2,27 +2,32 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pytest import FixtureRequest
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
RepoCivitaiVersionMetadata1,
|
||||
RepoHFMetadata1,
|
||||
@ -85,15 +90,77 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
app_config = InvokeAIAppConfig(
|
||||
root=mm2_root_dir,
|
||||
models_dir=mm2_root_dir / "models",
|
||||
log_level="info",
|
||||
)
|
||||
return app_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
def stop_queue() -> None:
|
||||
download_queue.stop()
|
||||
|
||||
request.addfinalizer(stop_queue)
|
||||
return download_queue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||
return mm2_record_store.metadata_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
max_cache_size=mm2_app_config.ram_cache_size,
|
||||
max_vram_cache_size=mm2_app_config.vram_cache_size,
|
||||
)
|
||||
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
||||
return ModelLoadService(
|
||||
app_config=mm2_app_config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(
|
||||
mm2_app_config: InvokeAIAppConfig,
|
||||
mm2_download_queue: DownloadQueueServiceBase,
|
||||
mm2_session: Session,
|
||||
request: FixtureRequest,
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=mm2_download_queue,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
|
||||
def stop_installer() -> None:
|
||||
installer.stop()
|
||||
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
|
||||
|
||||
request.addfinalizer(stop_installer)
|
||||
return installer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
store = ModelRecordServiceSQL(db)
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
# add five simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo1",
|
||||
@ -152,15 +219,16 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||
db = mm2_record_store._db # to ensure we are sharing the same database
|
||||
return ModelMetadataStore(db)
|
||||
def mm2_model_manager(
|
||||
mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
|
||||
) -> ModelManagerServiceBase:
|
||||
return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
"""This fixtures defines a series of mock URLs for testing download and installation."""
|
||||
sess = TestSession()
|
||||
sess: Session = TestSession()
|
||||
sess.mount(
|
||||
"https://test.com/missing_model.safetensors",
|
||||
TestAdapter(
|
||||
@ -240,26 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
),
|
||||
)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
metadata_store = ModelMetadataStore(db)
|
||||
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=download_queue,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
File diff suppressed because one or more lines are too long
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test model metadata fetching and storage.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@ -8,6 +9,7 @@ import pytest
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
CivitaiMetadata,
|
||||
@ -15,14 +17,13 @@ from invokeai.backend.model_manager.metadata import (
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.util import select_hf_files
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
tags = {"text-to-image", "diffusers"}
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
@ -40,7 +41,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
assert mm2_metadata_store.list_tags() == tags
|
||||
|
||||
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
@ -57,7 +58,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
@ -133,7 +134,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
||||
from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2
|
||||
|
||||
|
||||
def test_libc_util_mallinfo2():
|
@ -5,8 +5,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.lora import ModelPatcher
|
||||
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_management.libc_util import Struct_mallinfo2
|
||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
|
||||
|
||||
|
||||
def test_memory_snapshot_capture():
|
||||
@ -26,6 +26,7 @@ snapshots = [
|
||||
def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
|
||||
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
|
||||
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
|
||||
print(msg)
|
||||
|
||||
expected_lines = 0
|
||||
if snapshot_1 is not None and snapshot_2 is not None:
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
|
||||
from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]:
|
||||
"text_encoder/model.onnx",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.onnx",
|
||||
"text_encoder_2/model.onnx_data",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]:
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/model.onnx",
|
||||
"unet/model.onnx_data",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/model.onnx",
|
||||
"vae_encoder/config.json",
|
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test interaction of logging with configuration system.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
|
@ -4,4 +4,67 @@
|
||||
|
||||
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
||||
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
||||
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.images.images_default import ImageService
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(configuration, logger)
|
||||
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(
|
||||
board_image_records=SqliteBoardImageRecordStorage(db=db),
|
||||
board_images=None, # type: ignore
|
||||
board_records=SqliteBoardRecordStorage(db=db),
|
||||
boards=None, # type: ignore
|
||||
bulk_download=BulkDownloadService(),
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=ImageService(),
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
tensors=None, # type: ignore
|
||||
conditioning=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||
return Invoker(services=mock_services)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def invokeai_root_dir(tmp_path_factory) -> Path:
|
||||
root_template = Path(__file__).parent.resolve() / "backend/model_manager/data/invokeai_root"
|
||||
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
|
||||
shutil.copytree(root_template, temp_dir)
|
||||
return temp_dir
|
||||
|
@ -1,27 +1,18 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
PromptCollectionTestInvocation,
|
||||
PromptTestInvocation,
|
||||
TestEventService,
|
||||
TextToImageTestInvocation,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from invokeai.app.invocations.collections import RangeInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Graph,
|
||||
@ -29,11 +20,11 @@ from invokeai.app.services.shared.graph import (
|
||||
IterateInvocation,
|
||||
)
|
||||
|
||||
from .test_invoker import create_edge
|
||||
from .test_nodes import create_edge
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_graph():
|
||||
def simple_graph() -> Graph:
|
||||
g = Graph()
|
||||
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
g.add_node(TextToImageTestInvocation(id="2"))
|
||||
@ -41,69 +32,23 @@ def simple_graph():
|
||||
return g
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
@pytest.fixture
|
||||
def mock_services() -> InvocationServices:
|
||||
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
|
||||
return InvocationServices(
|
||||
board_image_records=None, # type: ignore
|
||||
board_images=None, # type: ignore
|
||||
board_records=None, # type: ignore
|
||||
boards=None, # type: ignore
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
latents=None, # type: ignore
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
model_records=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
model_install=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
queue=MemoryInvocationQueue(),
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||
def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]:
|
||||
n = g.next()
|
||||
if n is None:
|
||||
return (None, None)
|
||||
|
||||
print(f"invoking {n.id}: {type(n)}")
|
||||
o = n.invoke(
|
||||
InvocationContext(
|
||||
queue_batch_id="1",
|
||||
queue_item_id=1,
|
||||
queue_id=DEFAULT_QUEUE_ID,
|
||||
services=services,
|
||||
graph_execution_state_id="1",
|
||||
workflow=None,
|
||||
)
|
||||
)
|
||||
o = n.invoke(Mock(InvocationContext))
|
||||
g.complete(n.id, o)
|
||||
|
||||
return (n, o)
|
||||
|
||||
|
||||
def test_graph_state_executes_in_order(simple_graph, mock_services):
|
||||
def test_graph_state_executes_in_order(simple_graph: Graph):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = invoke_next(g, mock_services)
|
||||
n1 = invoke_next(g)
|
||||
n2 = invoke_next(g)
|
||||
n3 = g.next()
|
||||
|
||||
assert g.prepared_source_mapping[n1[0].id] == "1"
|
||||
@ -113,18 +58,18 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
|
||||
assert n2[0].prompt == n1[0].prompt
|
||||
|
||||
|
||||
def test_graph_is_complete(simple_graph, mock_services):
|
||||
def test_graph_is_complete(simple_graph: Graph):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = g.next()
|
||||
|
||||
assert g.is_complete()
|
||||
|
||||
|
||||
def test_graph_is_not_complete(simple_graph, mock_services):
|
||||
def test_graph_is_not_complete(simple_graph: Graph):
|
||||
g = GraphExecutionState(graph=simple_graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = g.next()
|
||||
|
||||
assert not g.is_complete()
|
||||
@ -133,7 +78,7 @@ def test_graph_is_not_complete(simple_graph, mock_services):
|
||||
# TODO: test completion with iterators/subgraphs
|
||||
|
||||
|
||||
def test_graph_state_expands_iterator(mock_services):
|
||||
def test_graph_state_expands_iterator():
|
||||
graph = Graph()
|
||||
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
|
||||
graph.add_node(IterateInvocation(id="1"))
|
||||
@ -145,7 +90,7 @@ def test_graph_state_expands_iterator(mock_services):
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
while not g.is_complete():
|
||||
invoke_next(g, mock_services)
|
||||
invoke_next(g)
|
||||
|
||||
prepared_add_nodes = g.source_prepared_mapping["3"]
|
||||
results = {g.results[n].value for n in prepared_add_nodes}
|
||||
@ -153,7 +98,7 @@ def test_graph_state_expands_iterator(mock_services):
|
||||
assert results == expected
|
||||
|
||||
|
||||
def test_graph_state_collects(mock_services):
|
||||
def test_graph_state_collects():
|
||||
graph = Graph()
|
||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||
graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
|
||||
@ -165,19 +110,19 @@ def test_graph_state_collects(mock_services):
|
||||
graph.add_edge(create_edge("3", "prompt", "4", "item"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
n6 = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
n6 = invoke_next(g)
|
||||
|
||||
assert isinstance(n6[0], CollectInvocation)
|
||||
|
||||
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
||||
|
||||
|
||||
def test_graph_state_prepares_eagerly(mock_services):
|
||||
def test_graph_state_prepares_eagerly():
|
||||
"""Tests that all prepareable nodes are prepared"""
|
||||
graph = Graph()
|
||||
|
||||
@ -206,7 +151,7 @@ def test_graph_state_prepares_eagerly(mock_services):
|
||||
assert "prompt_iterated" not in g.source_prepared_mapping
|
||||
|
||||
|
||||
def test_graph_executes_depth_first(mock_services):
|
||||
def test_graph_executes_depth_first():
|
||||
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
|
||||
graph = Graph()
|
||||
|
||||
@ -220,14 +165,14 @@ def test_graph_executes_depth_first(mock_services):
|
||||
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
_ = invoke_next(g)
|
||||
|
||||
# Because ordering is not guaranteed, we cannot compare results directly.
|
||||
# Instead, we must count the number of results.
|
||||
def get_completed_count(g, id):
|
||||
def get_completed_count(g: GraphExecutionState, id: str):
|
||||
ids = list(g.source_prepared_mapping[id])
|
||||
completed_ids = [i for i in g.executed if i in ids]
|
||||
return len(completed_ids)
|
||||
@ -236,17 +181,17 @@ def test_graph_executes_depth_first(mock_services):
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 0
|
||||
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
_ = invoke_next(g, mock_services)
|
||||
_ = invoke_next(g)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 2
|
@ -1,47 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType
|
||||
|
||||
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager(datadir) -> ModelManager:
|
||||
InvokeAIAppConfig.get_config(root=datadir)
|
||||
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
|
||||
|
||||
|
||||
def test_get_model_names(model_manager: ModelManager):
|
||||
names = model_manager.model_names()
|
||||
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
|
||||
|
||||
|
||||
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
|
||||
top_model_path, is_override = model_manager._get_model_path(model_config)
|
||||
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
|
||||
assert top_model_path == expected_model_path
|
||||
assert not is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(
|
||||
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
|
||||
)
|
||||
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
|
||||
assert vae_model_path == expected_vae_path
|
||||
assert is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(
|
||||
VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
|
||||
)
|
||||
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||
assert not is_override
|
@ -2,8 +2,8 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.backend import BaseModelType
|
||||
from invokeai.backend.model_management.model_probe import VaeFolderProbe
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.probe import VaeFolderProbe
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -20,3 +20,11 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
|
||||
probe = VaeFolderProbe(sd1_vae_path)
|
||||
base_type = probe.get_base_type()
|
||||
assert base_type == expected_type
|
||||
repo_variant = probe.get_repo_variant()
|
||||
assert repo_variant == ModelRepoVariant.DEFAULT
|
||||
|
||||
|
||||
def test_repo_variant(datadir: Path):
|
||||
probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16")
|
||||
repo_variant = probe.get_repo_variant()
|
||||
assert repo_variant == ModelRepoVariant.FP16
|
||||
|
37
tests/test_model_probe/vae/taesdxl-fp16/config.json
Normal file
37
tests/test_model_probe/vae/taesdxl-fp16/config.json
Normal file
@ -0,0 +1,37 @@
|
||||
{
|
||||
"_class_name": "AutoencoderTiny",
|
||||
"_diffusers_version": "0.20.0.dev0",
|
||||
"act_fn": "relu",
|
||||
"decoder_block_out_channels": [
|
||||
64,
|
||||
64,
|
||||
64,
|
||||
64
|
||||
],
|
||||
"encoder_block_out_channels": [
|
||||
64,
|
||||
64,
|
||||
64,
|
||||
64
|
||||
],
|
||||
"force_upcast": false,
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"latent_magnitude": 3,
|
||||
"latent_shift": 0.5,
|
||||
"num_decoder_blocks": [
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
1
|
||||
],
|
||||
"num_encoder_blocks": [
|
||||
1,
|
||||
3,
|
||||
3,
|
||||
3
|
||||
],
|
||||
"out_channels": 3,
|
||||
"scaling_factor": 1.0,
|
||||
"upsampling_scaling_factor": 2
|
||||
}
|
@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import (
|
||||
FloatCollectionInvocation,
|
||||
FloatInvocation,
|
||||
@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import (
|
||||
StringInvocation,
|
||||
)
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphInvocation,
|
||||
InvalidEdgeError,
|
||||
IterateInvocation,
|
||||
NodeAlreadyInGraphError,
|
||||
@ -425,21 +421,6 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_subgraph_invalid():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
|
||||
n1.graph.nodes[n1_1.id] = n1_1
|
||||
e1 = create_edge("1", "image", "2", "image")
|
||||
n1.graph.edges.append(e1)
|
||||
|
||||
g.nodes[n1.id] = n1
|
||||
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_has_cycle():
|
||||
g = Graph()
|
||||
n1 = ESRGANInvocation(id="1")
|
||||
@ -466,110 +447,6 @@ def test_graph_invalid_with_invalid_connection():
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
# TODO: Subgraph operations
|
||||
def test_graph_gets_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
result = g.get_node("1.1")
|
||||
|
||||
assert result is not None
|
||||
assert result.id == "1"
|
||||
assert result == n1_1
|
||||
|
||||
|
||||
def test_graph_expands_subgraph():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
|
||||
n1_1 = AddInvocation(id="1", a=1, b=2)
|
||||
n1_2 = SubtractInvocation(id="2", b=3)
|
||||
n1.graph.add_node(n1_1)
|
||||
n1.graph.add_node(n1_2)
|
||||
n1.graph.add_edge(create_edge("1", "value", "2", "a"))
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = AddInvocation(id="2", b=5)
|
||||
g.add_node(n2)
|
||||
g.add_edge(create_edge("1.2", "value", "2", "a"))
|
||||
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == {"1.1", "1.2", "2"}
|
||||
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
|
||||
|
||||
|
||||
def test_graph_subgraph_t2i():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
|
||||
# Get text to image default graph
|
||||
lg = create_text_to_image()
|
||||
n1.graph = lg.graph
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = IntegerInvocation(id="2", value=512)
|
||||
n3 = IntegerInvocation(id="3", value=256)
|
||||
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
|
||||
g.add_edge(create_edge("2", "value", "1.width", "value"))
|
||||
g.add_edge(create_edge("3", "value", "1.height", "value"))
|
||||
|
||||
n4 = ShowImageInvocation(id="4")
|
||||
g.add_node(n4)
|
||||
g.add_edge(create_edge("1.8", "image", "4", "image"))
|
||||
|
||||
# Validate
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
|
||||
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
|
||||
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
|
||||
print(expected_edges)
|
||||
print(list(dg.edges))
|
||||
assert set(dg.edges) == set(expected_edges)
|
||||
|
||||
|
||||
def test_graph_fails_to_get_missing_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
_ = g.get_node("1.2")
|
||||
|
||||
|
||||
def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = ESRGANInvocation(id="2")
|
||||
g.add_node(n2)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
_ = g.get_node("2.1")
|
||||
|
||||
|
||||
def test_graph_gets_networkx_graph():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
@ -1,15 +1,16 @@
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
# Define test invocations before importing anything that uses invocations
|
||||
@ -116,25 +117,22 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
||||
)
|
||||
|
||||
|
||||
class TestEvent:
|
||||
event_name: str
|
||||
payload: Any
|
||||
class TestEvent(BaseModel):
|
||||
__test__ = False # not a pytest test case
|
||||
|
||||
def __init__(self, event_name: str, payload: Any):
|
||||
self.event_name = event_name
|
||||
self.payload = payload
|
||||
event_name: str
|
||||
payload: Any
|
||||
|
||||
|
||||
class TestEventService(EventServiceBase):
|
||||
events: list
|
||||
__test__ = False # not a pytest test case
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.events = []
|
||||
self.events: list[TestEvent] = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
pass
|
||||
|
||||
|
172
tests/test_object_serializer_disk.py
Normal file
172
tests/test_object_serializer_disk.py
Normal file
@ -0,0 +1,172 @@
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockDataclass:
|
||||
foo: str
|
||||
|
||||
|
||||
def count_files(path: Path):
|
||||
return len(list(path.iterdir()))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def obj_serializer(tmp_path: Path):
|
||||
return ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fwd_cache(tmp_path: Path):
|
||||
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
|
||||
|
||||
|
||||
def test_obj_serializer_disk_initializes(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
assert obj_serializer._output_dir == tmp_path
|
||||
|
||||
|
||||
def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||
|
||||
obj_2 = MockDataclass(foo="baz")
|
||||
obj_2_name = obj_serializer.save(obj_2)
|
||||
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
assert obj_serializer.load(obj_1_name).foo == "bar"
|
||||
|
||||
obj_2 = MockDataclass(foo="baz")
|
||||
obj_2_name = obj_serializer.save(obj_2)
|
||||
assert obj_serializer.load(obj_2_name).foo == "baz"
|
||||
|
||||
with pytest.raises(ObjectNotFoundError):
|
||||
obj_serializer.load("nonexistent_object_name")
|
||||
|
||||
|
||||
def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
|
||||
obj_2 = MockDataclass(foo="bar")
|
||||
obj_2_name = obj_serializer.save(obj_2)
|
||||
|
||||
obj_serializer.delete(obj_1_name)
|
||||
assert not Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
|
||||
assert obj_serializer._base_output_dir == tmp_path
|
||||
assert obj_serializer._output_dir != tmp_path
|
||||
assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name)
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
tempdir_path = obj_serializer._output_dir
|
||||
del obj_serializer
|
||||
assert not tempdir_path.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
tempdir_path = obj_serializer._output_dir
|
||||
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
|
||||
assert not tempdir_path.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||
assert not Path(tmp_path, obj_1_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer_1.save(obj_1)
|
||||
obj_1_loaded = obj_serializer_1.load(obj_1_name)
|
||||
assert obj_serializer_1._obj_class_name == "MockDataclass"
|
||||
assert isinstance(obj_1_loaded, MockDataclass)
|
||||
assert obj_1_loaded.foo == "bar"
|
||||
assert obj_1_name.startswith("MockDataclass_")
|
||||
|
||||
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
|
||||
obj_2_name = obj_serializer_2.save(9001)
|
||||
assert obj_serializer_2._obj_class_name == "int"
|
||||
assert obj_serializer_2.load(obj_2_name) == 9001
|
||||
assert obj_2_name.startswith("int_")
|
||||
|
||||
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
|
||||
obj_3_name = obj_serializer_3.save("foo")
|
||||
assert obj_serializer_3._obj_class_name == "str"
|
||||
assert obj_serializer_3.load(obj_3_name) == "foo"
|
||||
assert obj_3_name.startswith("str_")
|
||||
|
||||
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
|
||||
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
|
||||
obj_4_loaded = obj_serializer_4.load(obj_4_name)
|
||||
assert obj_serializer_4._obj_class_name == "Tensor"
|
||||
assert isinstance(obj_4_loaded, torch.Tensor)
|
||||
assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3]))
|
||||
assert obj_4_name.startswith("Tensor_")
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
fwd_cache = ObjectSerializerForwardCache(obj_serializer)
|
||||
assert fwd_cache._underlying_storage == obj_serializer
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
|
||||
obj = MockDataclass(foo="bar")
|
||||
obj_name = fwd_cache.save(obj)
|
||||
obj_loaded = fwd_cache.load(obj_name)
|
||||
obj_underlying = fwd_cache._underlying_storage.load(obj_name)
|
||||
assert obj_loaded == obj_underlying
|
||||
assert obj_loaded.foo == "bar"
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = fwd_cache.save(obj_1)
|
||||
obj_2 = MockDataclass(foo="baz")
|
||||
obj_2_name = fwd_cache.save(obj_2)
|
||||
obj_3 = MockDataclass(foo="qux")
|
||||
obj_3_name = fwd_cache.save(obj_3)
|
||||
assert obj_1_name not in fwd_cache._cache
|
||||
assert obj_2_name in fwd_cache._cache
|
||||
assert obj_3_name in fwd_cache._cache
|
||||
# apparently qsize is "not reliable"?
|
||||
assert fwd_cache._cache_ids.qsize() == 2
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
|
||||
called_name = None
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
|
||||
def on_deleted(name: str):
|
||||
nonlocal called_name
|
||||
called_name = name
|
||||
|
||||
fwd_cache.on_deleted(on_deleted)
|
||||
obj_1_name = fwd_cache.save(obj_1)
|
||||
fwd_cache.delete(obj_1_name)
|
||||
assert called_name == obj_1_name
|
@ -2,6 +2,7 @@
|
||||
Not really a test, but a way to verify that the paths are existing
|
||||
and fail early if they are not.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import unittest
|
||||
from os import path as osp
|
||||
|
@ -8,11 +8,11 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
NodeFieldValue,
|
||||
calc_session_count,
|
||||
create_session_nfv_tuples,
|
||||
populate_graph,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
from tests.aa_nodes.test_nodes import PromptTestInvocation
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
|
||||
from .test_nodes import PromptTestInvocation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -39,30 +39,6 @@ def batch_graph() -> Graph:
|
||||
return g
|
||||
|
||||
|
||||
def test_populate_graph_with_subgraph():
|
||||
g1 = Graph()
|
||||
g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
|
||||
n1 = PromptTestInvocation(id="1", prompt="Banana snake")
|
||||
subgraph = Graph()
|
||||
subgraph.add_node(n1)
|
||||
g1.add_node(GraphInvocation(id="3", graph=subgraph))
|
||||
|
||||
nfvs = [
|
||||
NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
|
||||
NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
|
||||
NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
|
||||
]
|
||||
|
||||
g2 = populate_graph(g1, nfvs)
|
||||
|
||||
# do not mutate g1
|
||||
assert g1 is not g2
|
||||
assert g2.get_node("1").prompt == "Strawberry sushi"
|
||||
assert g2.get_node("2").prompt == "Strawberry sunday"
|
||||
assert g2.get_node("3.1").prompt == "Strawberry snake"
|
||||
|
||||
|
||||
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
|
||||
b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
|
||||
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
|
Reference in New Issue
Block a user