mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: parse config before importing anything else
We need to parse the config before doing anything related to invocations to ensure that the invocations union picks up on denied nodes. - Move that to the top of api_app and cli_app - Wrap subsequent imports in `if True:`, as a hack to satisfy flake8 and not have to noqa every line or the whole file - Add tests to ensure graph validation fails when using a denied node, and that the invocations union does not have denied nodes (this indirectly provides confidence that the generated OpenAPI schema will not include denied nodes)
This commit is contained in:
parent
1d2636aa90
commit
4395ee3c03
@ -1,10 +1,21 @@
|
|||||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
|
# values from the command line or config file.
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
app_config.parse_args()
|
||||||
|
|
||||||
|
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import mimetypes
|
||||||
import socket
|
import socket
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@ -15,31 +26,22 @@ from fastapi_events.handlers.local import local_handler
|
|||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
from pydantic.schema import schema
|
from pydantic.schema import schema
|
||||||
|
|
||||||
from .services.config import InvokeAIAppConfig
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
|
||||||
|
|
||||||
import invokeai.frontend.web as web_dir
|
|
||||||
import mimetypes
|
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
|
||||||
from .api.sockets import SocketIO
|
|
||||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
import invokeai.frontend.web as web_dir
|
||||||
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
|
from ..backend.util.logging import InvokeAILogger
|
||||||
|
from .api.dependencies import ApiDependencies
|
||||||
|
from .api.routers import app_info, board_images, boards, images, models, sessions
|
||||||
|
from .api.sockets import SocketIO
|
||||||
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
app_config.parse_args()
|
|
||||||
logger = InvokeAILogger.getLogger(config=app_config)
|
logger = InvokeAILogger.getLogger(config=app_config)
|
||||||
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
# fix for windows mimetypes registry entries being borked
|
||||||
|
@ -1,41 +1,42 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
|
# values from the command line or config file.
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
|
||||||
|
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||||
import argparse
|
import argparse
|
||||||
import re
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Union, get_type_hints, Optional
|
from typing import Optional, Union, get_type_hints
|
||||||
|
|
||||||
|
import torch
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
# This should come early so that the logger can pick up its configuration options
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
from .services.config import InvokeAIAppConfig
|
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
||||||
from invokeai.version.invokeai_version import __version__
|
|
||||||
|
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
|
||||||
SqliteBoardImageRecordStorage,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.board_images import (
|
|
||||||
BoardImagesService,
|
|
||||||
BoardImagesServiceDependencies,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from invokeai.version.invokeai_version import __version__
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
|
||||||
|
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
||||||
from .cli.completer import set_autocompleter
|
from .cli.completer import set_autocompleter
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
|
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.graph import (
|
from .services.graph import (
|
||||||
Edge,
|
Edge,
|
||||||
@ -49,19 +50,15 @@ from .services.image_file_storage import DiskImageFileStorage
|
|||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from .services.model_manager_service import ModelManagerService
|
from .services.model_manager_service import ModelManagerService
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
|
||||||
import torch
|
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
config.parse_args()
|
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -147,3 +148,43 @@ def test_type_coercion(patch_rootdir):
|
|||||||
conf.parse_args(argv=["--root=/tmp/foobar"])
|
conf.parse_args(argv=["--root=/tmp/foobar"])
|
||||||
assert conf.root == Path("/tmp/different")
|
assert conf.root == Path("/tmp/different")
|
||||||
assert isinstance(conf.root, Path)
|
assert isinstance(conf.root, Path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deny_nodes(patch_rootdir):
|
||||||
|
# Allow integer, string and float, but explicitly deny float
|
||||||
|
allow_deny_nodes_conf = OmegaConf.create(
|
||||||
|
"""
|
||||||
|
InvokeAI:
|
||||||
|
Nodes:
|
||||||
|
allow_nodes:
|
||||||
|
- integer
|
||||||
|
- string
|
||||||
|
- float
|
||||||
|
deny_nodes:
|
||||||
|
- float
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# must parse config before importing Graph, so its nodes union uses the config
|
||||||
|
conf = InvokeAIAppConfig().get_config()
|
||||||
|
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
|
|
||||||
|
# confirm graph validation fails when using denied node
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
|
||||||
|
# confirm invocations union will not have denied nodes
|
||||||
|
all_invocations = BaseInvocation.get_invocations()
|
||||||
|
|
||||||
|
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
|
||||||
|
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
|
||||||
|
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
|
||||||
|
|
||||||
|
assert has_integer
|
||||||
|
assert has_string
|
||||||
|
assert not has_float
|
||||||
|
Loading…
Reference in New Issue
Block a user