mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: parse config before importing anything else
We need to parse the config before doing anything related to invocations to ensure that the invocations union picks up on denied nodes. - Move that to the top of api_app and cli_app - Wrap subsequent imports in `if True:`, as a hack to satisfy flake8 and not have to noqa every line or the whole file - Add tests to ensure graph validation fails when using a denied node, and that the invocations union does not have denied nodes (this indirectly provides confidence that the generated OpenAPI schema will not include denied nodes)
This commit is contained in:
parent
1d2636aa90
commit
4395ee3c03
@ -1,45 +1,47 @@
|
|||||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import socket
|
|
||||||
from inspect import signature
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
|
||||||
from fastapi.openapi.utils import get_openapi
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
|
||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|
||||||
from pydantic.schema import schema
|
|
||||||
|
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
from ..backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
|
# values from the command line or config file.
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
app_config.parse_args()
|
||||||
|
|
||||||
import invokeai.frontend.web as web_dir
|
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||||
import mimetypes
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import socket
|
||||||
|
from inspect import signature
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
import torch
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
import uvicorn
|
||||||
from .api.sockets import SocketIO
|
from fastapi import FastAPI
|
||||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi_events.handlers.local import local_handler
|
||||||
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
|
from pydantic.schema import schema
|
||||||
|
|
||||||
import torch
|
# noinspection PyUnresolvedReferences
|
||||||
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
import invokeai.frontend.web as web_dir
|
||||||
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
from ..backend.util.logging import InvokeAILogger
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
from .api.dependencies import ApiDependencies
|
||||||
|
from .api.routers import app_info, board_images, boards, images, models, sessions
|
||||||
|
from .api.sockets import SocketIO
|
||||||
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
app_config.parse_args()
|
|
||||||
logger = InvokeAILogger.getLogger(config=app_config)
|
logger = InvokeAILogger.getLogger(config=app_config)
|
||||||
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
# fix for windows mimetypes registry entries being borked
|
||||||
|
@ -1,67 +1,64 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import argparse
|
|
||||||
import re
|
|
||||||
import shlex
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from typing import Union, get_type_hints, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
from pydantic.fields import Field
|
|
||||||
|
|
||||||
# This should come early so that the logger can pick up its configuration options
|
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
|
||||||
|
|
||||||
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
|
# values from the command line or config file.
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||||
SqliteBoardImageRecordStorage,
|
import argparse
|
||||||
)
|
import re
|
||||||
from invokeai.app.services.board_images import (
|
import shlex
|
||||||
BoardImagesService,
|
import sys
|
||||||
BoardImagesServiceDependencies,
|
import time
|
||||||
)
|
from typing import Optional, Union, get_type_hints
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
|
||||||
|
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
import torch
|
||||||
from .cli.completer import set_autocompleter
|
from pydantic import BaseModel, ValidationError
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from pydantic.fields import Field
|
||||||
from .services.events import EventServiceBase
|
|
||||||
from .services.graph import (
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||||
|
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
||||||
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
|
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
||||||
|
from .cli.completer import set_autocompleter
|
||||||
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
|
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
|
||||||
|
from .services.events import EventServiceBase
|
||||||
|
from .services.graph import (
|
||||||
Edge,
|
Edge,
|
||||||
EdgeConnection,
|
EdgeConnection,
|
||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
GraphInvocation,
|
GraphInvocation,
|
||||||
LibraryGraph,
|
LibraryGraph,
|
||||||
are_connection_types_compatible,
|
are_connection_types_compatible,
|
||||||
)
|
)
|
||||||
from .services.image_file_storage import DiskImageFileStorage
|
from .services.image_file_storage import DiskImageFileStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
from .services.model_manager_service import ModelManagerService
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.model_manager_service import ModelManagerService
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.processor import DefaultInvocationProcessor
|
||||||
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
|
||||||
import torch
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
config.parse_args()
|
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -147,3 +148,43 @@ def test_type_coercion(patch_rootdir):
|
|||||||
conf.parse_args(argv=["--root=/tmp/foobar"])
|
conf.parse_args(argv=["--root=/tmp/foobar"])
|
||||||
assert conf.root == Path("/tmp/different")
|
assert conf.root == Path("/tmp/different")
|
||||||
assert isinstance(conf.root, Path)
|
assert isinstance(conf.root, Path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deny_nodes(patch_rootdir):
|
||||||
|
# Allow integer, string and float, but explicitly deny float
|
||||||
|
allow_deny_nodes_conf = OmegaConf.create(
|
||||||
|
"""
|
||||||
|
InvokeAI:
|
||||||
|
Nodes:
|
||||||
|
allow_nodes:
|
||||||
|
- integer
|
||||||
|
- string
|
||||||
|
- float
|
||||||
|
deny_nodes:
|
||||||
|
- float
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# must parse config before importing Graph, so its nodes union uses the config
|
||||||
|
conf = InvokeAIAppConfig().get_config()
|
||||||
|
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
|
|
||||||
|
# confirm graph validation fails when using denied node
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
|
||||||
|
# confirm invocations union will not have denied nodes
|
||||||
|
all_invocations = BaseInvocation.get_invocations()
|
||||||
|
|
||||||
|
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
|
||||||
|
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
|
||||||
|
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
|
||||||
|
|
||||||
|
assert has_integer
|
||||||
|
assert has_string
|
||||||
|
assert not has_float
|
||||||
|
Loading…
Reference in New Issue
Block a user