From 4395ee3c03be9fc17de13f458c3954368d620400 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:41:00 +1000 Subject: [PATCH] feat: parse config before importing anything else We need to parse the config before doing anything related to invocations to ensure that the invocations union picks up on denied nodes. - Move that to the top of api_app and cli_app - Wrap subsequent imports in `if True:`, as a hack to satisfy flake8 and not have to noqa every line or the whole file - Add tests to ensure graph validation fails when using a denied node, and that the invocations union does not have denied nodes (this indirectly provides confidence that the generated OpenAPI schema will not include denied nodes) --- invokeai/app/api_app.py | 76 ++++++++++++++------------- invokeai/app/cli_app.py | 113 +++++++++++++++++++--------------------- tests/test_config.py | 41 +++++++++++++++ 3 files changed, 135 insertions(+), 95 deletions(-) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 902af0c02c..93f02b3446 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,45 +1,47 @@ # Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team -import asyncio -import logging -import socket -from inspect import signature -from pathlib import Path - -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html -from fastapi.openapi.utils import get_openapi -from fastapi.staticfiles import StaticFiles -from fastapi_events.handlers.local import local_handler -from fastapi_events.middleware import EventHandlerASGIMiddleware -from pydantic.schema import schema - from .services.config import InvokeAIAppConfig -from ..backend.util.logging import InvokeAILogger - -from invokeai.version.invokeai_version import __version__ - -import invokeai.frontend.web as web_dir -import mimetypes - -from .api.dependencies import ApiDependencies -from .api.routers import sessions, models, images, boards, board_images, app_info -from .api.sockets import SocketIO -from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase - -import torch - -# noinspection PyUnresolvedReferences -import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) - -if torch.backends.mps.is_available(): - # noinspection PyUnresolvedReferences - import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) - +# parse_args() must be called before any other imports. if it is not called first, consumers of the config +# which are imported/used before parse_args() is called will get the default config values instead of the +# values from the command line or config file. app_config = InvokeAIAppConfig.get_config() app_config.parse_args() + +if True: # hack to make flake8 happy with imports coming after setting up the config + import asyncio + import logging + import mimetypes + import socket + from inspect import signature + from pathlib import Path + + import torch + import uvicorn + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html + from fastapi.openapi.utils import get_openapi + from fastapi.staticfiles import StaticFiles + from fastapi_events.handlers.local import local_handler + from fastapi_events.middleware import EventHandlerASGIMiddleware + from pydantic.schema import schema + + # noinspection PyUnresolvedReferences + import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) + import invokeai.frontend.web as web_dir + from invokeai.version.invokeai_version import __version__ + + from ..backend.util.logging import InvokeAILogger + from .api.dependencies import ApiDependencies + from .api.routers import app_info, board_images, boards, images, models, sessions + from .api.sockets import SocketIO + from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField + + if torch.backends.mps.is_available(): + # noinspection PyUnresolvedReferences + import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) + + logger = InvokeAILogger.getLogger(config=app_config) # fix for windows mimetypes registry entries being borked diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index dfde3d1f58..3b8410a88e 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -1,67 +1,64 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -import argparse -import re -import shlex -import sys -import time -from typing import Union, get_type_hints, Optional - -from pydantic import BaseModel, ValidationError -from pydantic.fields import Field - -# This should come early so that the logger can pick up its configuration options from .services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger -from invokeai.version.invokeai_version import __version__ - - -from invokeai.app.services.board_image_record_storage import ( - SqliteBoardImageRecordStorage, -) -from invokeai.app.services.board_images import ( - BoardImagesService, - BoardImagesServiceDependencies, -) -from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage -from invokeai.app.services.boards import BoardService, BoardServiceDependencies -from invokeai.app.services.image_record_storage import SqliteImageRecordStorage -from invokeai.app.services.images import ImageService, ImageServiceDependencies -from invokeai.app.services.resource_name import SimpleNameService -from invokeai.app.services.urls import LocalUrlService -from invokeai.app.services.invocation_stats import InvocationStatsService -from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs -from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage - -from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers -from .cli.completer import set_autocompleter -from .invocations.baseinvocation import BaseInvocation -from .services.events import EventServiceBase -from .services.graph import ( - Edge, - EdgeConnection, - GraphExecutionState, - GraphInvocation, - LibraryGraph, - are_connection_types_compatible, -) -from .services.image_file_storage import DiskImageFileStorage -from .services.invocation_queue import MemoryInvocationQueue -from .services.invocation_services import InvocationServices -from .services.invoker import Invoker -from .services.model_manager_service import ModelManagerService -from .services.processor import DefaultInvocationProcessor -from .services.sqlite import SqliteItemStorage - -import torch -import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) - -if torch.backends.mps.is_available(): - import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) - +# parse_args() must be called before any other imports. if it is not called first, consumers of the config +# which are imported/used before parse_args() is called will get the default config values instead of the +# values from the command line or config file. config = InvokeAIAppConfig.get_config() config.parse_args() + +if True: # hack to make flake8 happy with imports coming after setting up the config + import argparse + import re + import shlex + import sys + import time + from typing import Optional, Union, get_type_hints + + import torch + from pydantic import BaseModel, ValidationError + from pydantic.fields import Field + + import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) + from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage + from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies + from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage + from invokeai.app.services.boards import BoardService, BoardServiceDependencies + from invokeai.app.services.image_record_storage import SqliteImageRecordStorage + from invokeai.app.services.images import ImageService, ImageServiceDependencies + from invokeai.app.services.invocation_stats import InvocationStatsService + from invokeai.app.services.resource_name import SimpleNameService + from invokeai.app.services.urls import LocalUrlService + from invokeai.backend.util.logging import InvokeAILogger + from invokeai.version.invokeai_version import __version__ + + from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers + from .cli.completer import set_autocompleter + from .invocations.baseinvocation import BaseInvocation + from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id + from .services.events import EventServiceBase + from .services.graph import ( + Edge, + EdgeConnection, + GraphExecutionState, + GraphInvocation, + LibraryGraph, + are_connection_types_compatible, + ) + from .services.image_file_storage import DiskImageFileStorage + from .services.invocation_queue import MemoryInvocationQueue + from .services.invocation_services import InvocationServices + from .services.invoker import Invoker + from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage + from .services.model_manager_service import ModelManagerService + from .services.processor import DefaultInvocationProcessor + from .services.sqlite import SqliteItemStorage + + if torch.backends.mps.is_available(): + import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) + + logger = InvokeAILogger().getLogger(config=config) diff --git a/tests/test_config.py b/tests/test_config.py index 88da7a02ab..7805be4fd2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,6 @@ import os from typing import Any +from pydantic import ValidationError import pytest from omegaconf import OmegaConf @@ -147,3 +148,43 @@ def test_type_coercion(patch_rootdir): conf.parse_args(argv=["--root=/tmp/foobar"]) assert conf.root == Path("/tmp/different") assert isinstance(conf.root, Path) + + +def test_deny_nodes(patch_rootdir): + # Allow integer, string and float, but explicitly deny float + allow_deny_nodes_conf = OmegaConf.create( + """ + InvokeAI: + Nodes: + allow_nodes: + - integer + - string + - float + deny_nodes: + - float + """ + ) + # must parse config before importing Graph, so its nodes union uses the config + conf = InvokeAIAppConfig().get_config() + conf.parse_args(conf=allow_deny_nodes_conf, argv=[]) + from invokeai.app.services.graph import Graph + + # confirm graph validation fails when using denied node + Graph(nodes={"1": {"id": "1", "type": "integer"}}) + Graph(nodes={"1": {"id": "1", "type": "string"}}) + + with pytest.raises(ValidationError): + Graph(nodes={"1": {"id": "1", "type": "float"}}) + + from invokeai.app.invocations.baseinvocation import BaseInvocation + + # confirm invocations union will not have denied nodes + all_invocations = BaseInvocation.get_invocations() + + has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1 + has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1 + has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1 + + assert has_integer + assert has_string + assert not has_float