InvokeAI/tests/test_config.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

296 lines
10 KiB
Python
Raw Normal View History

from pathlib import Path
2024-03-11 13:24:30 +00:00
from tempfile import TemporaryDirectory
2023-08-17 22:45:25 +00:00
from typing import Any
2023-05-04 04:45:52 +00:00
2023-08-17 22:45:25 +00:00
import pytest
2023-05-04 04:45:52 +00:00
from omegaconf import OmegaConf
from pydantic import ValidationError
2023-05-04 04:45:52 +00:00
from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
get_config,
load_and_migrate_config,
)
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
2023-08-20 19:27:51 +00:00
2024-03-11 13:24:30 +00:00
# Minimal config file in the current (v4) schema.
v4_config = """
schema_version: 4.0.0

host: "192.168.1.1"
port: 8080
"""

# Same content as v4_config, but declaring a schema version the loader is expected to reject.
invalid_v5_config = """
schema_version: 5.0.0

host: "192.168.1.1"
port: 8080
"""

# Legacy (v3) config: every setting is nested under `InvokeAI:` and then under a category
# heading. The nesting is significant YAML structure, not cosmetic indentation.
v3_config = """
InvokeAI:
  Web Server:
    host: 192.168.1.1
    port: 8080
  Features:
    esrgan: true
    internet_available: true
    log_tokenization: false
    patchmatch: true
    ignore_missing_core_models: false
  Paths:
    outdir: /some/outputs/dir
    conf_path: /custom/models.yaml
  Model Cache:
    max_cache_size: 100
    max_vram_cache_size: 50
"""

# Legacy (v3) config whose `port` is not an integer, so migrating it is expected to fail.
v3_config_with_bad_values = """
InvokeAI:
  Web Server:
    port: "ice cream"
"""

# A bare YAML scalar rather than a mapping; loading it is expected to be rejected.
invalid_config = """
i like turtles
"""
@pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
    """Point INVOKEAI_ROOT at this test's temporary directory.

    This may be overkill since the current tests don't need the root dir to exist.
    """
    root_dir = str(tmp_path)
    monkeypatch.setenv("INVOKEAI_ROOT", root_dir)
def test_path_resolution_root_not_set(patch_rootdir: None):
    """Without an explicitly set root, root_path should equal what find_root() discovers."""
    cfg = InvokeAIAppConfig()
    discovered = InvokeAIAppConfig.find_root()
    assert cfg.root_path == discovered
def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
    """A v4 config file on disk is loaded with its values intact."""
    config_path = tmp_path / "temp_invokeai.yaml"
    config_path.write_text(v4_config)

    loaded = load_and_migrate_config(config_path)

    assert loaded.host == "192.168.1.1"
    assert loaded.port == 8080
def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
    """A legacy v3 config file is migrated, carrying old settings over to their new field names."""
    config_path = tmp_path / "temp_invokeai.yaml"
    config_path.write_text(v3_config)

    migrated = load_and_migrate_config(config_path)

    # v3 outdir / max_cache_size / conf_path map onto outputs_dir / ram / legacy_models_yaml_path.
    assert migrated.outputs_dir == Path("/some/outputs/dir")
    assert migrated.host == "192.168.1.1"
    assert migrated.port == 8080
    assert migrated.ram == 100
    assert migrated.legacy_models_yaml_path == Path("/custom/models.yaml")

    # Obsolete v3 settings such as `esrgan` should be stripped out.
    assert not hasattr(migrated, "esrgan")
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
    """Migrating a v3 file leaves a `.yaml.bak` copy holding the original contents."""
    config_path = tmp_path / "temp_invokeai.yaml"
    config_path.write_text(v3_config)
    load_and_migrate_config(config_path)

    backup_path = config_path.with_suffix(".yaml.bak")
    assert backup_path.exists()
    assert backup_path.read_text() == v3_config
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
    """A failed migration still writes a backup and leaves the original file unchanged."""
    config_path = tmp_path / "temp_invokeai.yaml"
    config_path.write_text(v3_config_with_bad_values)

    with pytest.raises(RuntimeError):
        load_and_migrate_config(config_path)

    backup_path = config_path.with_suffix(".yaml.bak")
    assert backup_path.exists()
    assert backup_path.read_text() == v3_config_with_bad_values

    # The source file must survive the failed migration untouched.
    assert config_path.exists()
    assert config_path.read_text() == v3_config_with_bad_values
def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
    """A file whose YAML is not a config mapping is rejected with an AssertionError."""
    config_path = tmp_path / "temp_invokeai.yaml"
    config_path.write_text(invalid_config)

    with pytest.raises(AssertionError):
        load_and_migrate_config(config_path)
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None):
    """A config declaring an unknown schema_version is rejected instead of being loaded."""
    config_path = tmp_path / "temp_invokeai.yaml"
    config_path.write_text(invalid_v5_config)

    with pytest.raises(RuntimeError, match="Invalid schema version"):
        load_and_migrate_config(config_path)
def test_write_config_to_file(patch_rootdir: None):
    """write_file() persists explicitly set values and omits values left at their defaults."""
    with TemporaryDirectory() as tmpdir:
        out_path = Path(tmpdir) / "invokeai.yaml"
        InvokeAIAppConfig(host="192.168.1.1", port=8080).write_file(out_path)
        # Read the file back while the temporary directory still exists.
        written = out_path.read_text()

    # pil_compress_level was never set, so as a default it should not be serialized.
    assert "pil_compress_level" not in written
    assert "host: 192.168.1.1" in written
    assert "port: 8080" in written
def test_update_config_with_dict(patch_rootdir: None):
    """update_config() applies overrides supplied as a plain dict."""
    config = InvokeAIAppConfig()
    config.update_config({"host": "10.10.10.10", "port": 6060})
    assert config.host == "10.10.10.10"
    assert config.port == 6060
def test_update_config_with_object(patch_rootdir: None):
    """update_config() also accepts another InvokeAIAppConfig as the source of overrides."""
    target = InvokeAIAppConfig()
    overrides = InvokeAIAppConfig(host="10.10.10.10", port=6060)

    target.update_config(overrides)

    assert target.host == "10.10.10.10"
    assert target.port == 6060
def test_set_and_resolve_paths(patch_rootdir: None):
    """Derived paths (models, database) are resolved relative to the root set on the config."""
    with TemporaryDirectory() as tmpdir:
        root = Path(tmpdir)
        config = InvokeAIAppConfig()
        config._root = root

        resolved_root = root.resolve()
        assert config.models_path == resolved_root / "models"
        assert config.db_path == resolved_root / "databases" / "invokeai.db"
def test_singleton_behavior(patch_rootdir: None):
    """get_config() is cached, so repeated calls return the very same instance."""
    # Start from an empty cache so the instance is built under this test's environment.
    get_config.cache_clear()
    try:
        first = get_config()
        second = get_config()
        assert first is second
    finally:
        # Always drop the cached instance, even when the assertion fails, so later tests
        # do not inherit a config built against this test's temporary root.
        get_config.cache_clear()
def test_default_config(patch_rootdir: None):
    """DefaultInvokeAIAppConfig exposes the stock default host."""
    defaults = DefaultInvokeAIAppConfig()
    assert defaults.host == "127.0.0.1"
def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """INVOKEAI_* environment variables are merged into the config."""
    env = {
        "INVOKEAI_ROOT": str(tmp_path),
        "INVOKEAI_HOST": "1.2.3.4",
        "INVOKEAI_PORT": "1234",
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)

    config = InvokeAIAppConfig()

    assert config.host == "1.2.3.4"
    assert config.port == 1234  # the string env value arrives as an int
    assert config.root_path == tmp_path
def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Test that get_config writes the config file and the example file to the root dir."""
    # Trick the config into thinking it has already parsed args - this triggers the writing of
    # the config file. monkeypatch restores the previous value at teardown even if an assertion
    # below fails; a manual reset at the end of the test would be skipped on failure and would
    # leak `did_parse = True` into every later test.
    monkeypatch.setattr(InvokeAIArgs, "did_parse", True)
    monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
    monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4")

    get_config.cache_clear()
    try:
        config = get_config()
    finally:
        # Never leave a cached instance bound to this test's temporary root.
        get_config.cache_clear()

    config_file_path = tmp_path / "invokeai.yaml"
    example_file_path = config_file_path.with_suffix(".example.yaml")

    assert config.config_file_path == config_file_path
    assert config_file_path.exists()
    assert example_file_path.exists()

    # The example file should have the default values
    example_file_content = example_file_path.read_text()
    assert "host: 127.0.0.1" in example_file_content
    assert "port: 9090" in example_file_content
    # It should also have the `remote_api_tokens` key
    assert "remote_api_tokens" in example_file_content

    # Neither env vars nor default values should be written to the config file
    config_file_content = config_file_path.read_text()
    assert "host" not in config_file_content
@pytest.mark.xfail(
    reason="""
    This test fails when run as part of the full test suite.

    This test needs to deny nodes from being included in the InvocationsUnion by providing
    an app configuration as a test fixture. Pytest executes all test files before running
    tests, so the app configuration is already initialized by the time this test runs, and
    the InvocationUnion is already created and the denied nodes are not omitted from it.

    This test passes when `test_config.py` is tested in isolation.

    Perhaps a solution would be to call `get_app_config().parse_args()` in
    other test files?
    """
)
def test_deny_nodes(patch_rootdir):
    """Nodes listed in deny_nodes are excluded even when they also appear in allow_nodes."""
    # Allow integer, string and float, but explicitly deny float.
    # The nesting matters: `Nodes` must sit under the top-level `InvokeAI` key.
    allow_deny_nodes_conf = OmegaConf.create(
        """
        InvokeAI:
          Nodes:
            allow_nodes:
              - integer
              - string
              - float
            deny_nodes:
              - float
        """
    )
    # must parse config before importing Graph, so its nodes union uses the config
    get_config.cache_clear()
    conf = get_config()
    get_config.cache_clear()
    conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])

    from invokeai.app.services.shared.graph import Graph

    # Allowed node types validate; the denied type must fail graph validation.
    Graph(nodes={"1": {"id": "1", "type": "integer"}})
    Graph(nodes={"1": {"id": "1", "type": "string"}})
    with pytest.raises(ValidationError):
        Graph(nodes={"1": {"id": "1", "type": "float"}})

    from invokeai.app.invocations.baseinvocation import BaseInvocation

    # confirm invocations union will not have denied nodes
    all_invocations = BaseInvocation.get_invocations()

    def count_of_type(type_name: str) -> int:
        """Number of registered invocations whose `type` field defaults to `type_name`."""
        return sum(1 for inv in all_invocations if inv.model_fields.get("type").default == type_name)

    assert count_of_type("integer") == 1
    assert count_of_type("string") == 1
    # A denied node must be absent entirely (the old `not (count == 1)` check also passed for 2+).
    assert count_of_type("float") == 0