diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 902af0c02c..93f02b3446 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,45 +1,47 @@ # Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team -import asyncio -import logging -import socket -from inspect import signature -from pathlib import Path - -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html -from fastapi.openapi.utils import get_openapi -from fastapi.staticfiles import StaticFiles -from fastapi_events.handlers.local import local_handler -from fastapi_events.middleware import EventHandlerASGIMiddleware -from pydantic.schema import schema - from .services.config import InvokeAIAppConfig -from ..backend.util.logging import InvokeAILogger - -from invokeai.version.invokeai_version import __version__ - -import invokeai.frontend.web as web_dir -import mimetypes - -from .api.dependencies import ApiDependencies -from .api.routers import sessions, models, images, boards, board_images, app_info -from .api.sockets import SocketIO -from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase - -import torch - -# noinspection PyUnresolvedReferences -import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) - -if torch.backends.mps.is_available(): - # noinspection PyUnresolvedReferences - import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) - +# parse_args() must be called before any other imports. if it is not called first, consumers of the config +# which are imported/used before parse_args() is called will get the default config values instead of the +# values from the command line or config file. app_config = InvokeAIAppConfig.get_config() app_config.parse_args() + +if True: # hack to make flake8 happy with imports coming after setting up the config + import asyncio + import logging + import mimetypes + import socket + from inspect import signature + from pathlib import Path + + import torch + import uvicorn + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html + from fastapi.openapi.utils import get_openapi + from fastapi.staticfiles import StaticFiles + from fastapi_events.handlers.local import local_handler + from fastapi_events.middleware import EventHandlerASGIMiddleware + from pydantic.schema import schema + + # noinspection PyUnresolvedReferences + import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) + import invokeai.frontend.web as web_dir + from invokeai.version.invokeai_version import __version__ + + from ..backend.util.logging import InvokeAILogger + from .api.dependencies import ApiDependencies + from .api.routers import app_info, board_images, boards, images, models, sessions + from .api.sockets import SocketIO + from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField + + if torch.backends.mps.is_available(): + # noinspection PyUnresolvedReferences + import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) + + logger = InvokeAILogger.getLogger(config=app_config) # fix for windows mimetypes registry entries being borked diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index dfde3d1f58..3b8410a88e 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -1,67 +1,64 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -import argparse -import re -import shlex -import sys -import time -from typing import Union, get_type_hints, Optional - -from pydantic import BaseModel, ValidationError -from pydantic.fields import Field - -# This should come early so that the logger can pick up its configuration options from .services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger -from invokeai.version.invokeai_version import __version__ - - -from invokeai.app.services.board_image_record_storage import ( - SqliteBoardImageRecordStorage, -) -from invokeai.app.services.board_images import ( - BoardImagesService, - BoardImagesServiceDependencies, -) -from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage -from invokeai.app.services.boards import BoardService, BoardServiceDependencies -from invokeai.app.services.image_record_storage import SqliteImageRecordStorage -from invokeai.app.services.images import ImageService, ImageServiceDependencies -from invokeai.app.services.resource_name import SimpleNameService -from invokeai.app.services.urls import LocalUrlService -from invokeai.app.services.invocation_stats import InvocationStatsService -from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs -from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage - -from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers -from .cli.completer import set_autocompleter -from .invocations.baseinvocation import BaseInvocation -from .services.events import EventServiceBase -from .services.graph import ( - Edge, - EdgeConnection, - GraphExecutionState, - GraphInvocation, - LibraryGraph, - are_connection_types_compatible, -) -from .services.image_file_storage import DiskImageFileStorage -from .services.invocation_queue import MemoryInvocationQueue -from .services.invocation_services import InvocationServices -from .services.invoker import Invoker -from .services.model_manager_service import ModelManagerService -from .services.processor import DefaultInvocationProcessor -from .services.sqlite import SqliteItemStorage - -import torch -import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) - -if torch.backends.mps.is_available(): - import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) - +# parse_args() must be called before any other imports. if it is not called first, consumers of the config +# which are imported/used before parse_args() is called will get the default config values instead of the +# values from the command line or config file. config = InvokeAIAppConfig.get_config() config.parse_args() + +if True: # hack to make flake8 happy with imports coming after setting up the config + import argparse + import re + import shlex + import sys + import time + from typing import Optional, Union, get_type_hints + + import torch + from pydantic import BaseModel, ValidationError + from pydantic.fields import Field + + import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) + from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage + from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies + from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage + from invokeai.app.services.boards import BoardService, BoardServiceDependencies + from invokeai.app.services.image_record_storage import SqliteImageRecordStorage + from invokeai.app.services.images import ImageService, ImageServiceDependencies + from invokeai.app.services.invocation_stats import InvocationStatsService + from invokeai.app.services.resource_name import SimpleNameService + from invokeai.app.services.urls import LocalUrlService + from invokeai.backend.util.logging import InvokeAILogger + from invokeai.version.invokeai_version import __version__ + + from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers + from .cli.completer import set_autocompleter + from .invocations.baseinvocation import BaseInvocation + from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id + from .services.events import EventServiceBase + from .services.graph import ( + Edge, + EdgeConnection, + GraphExecutionState, + GraphInvocation, + LibraryGraph, + are_connection_types_compatible, + ) + from .services.image_file_storage import DiskImageFileStorage + from .services.invocation_queue import MemoryInvocationQueue + from .services.invocation_services import InvocationServices + from .services.invoker import Invoker + from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage + from .services.model_manager_service import ModelManagerService + from .services.processor import DefaultInvocationProcessor + from .services.sqlite import SqliteItemStorage + + if torch.backends.mps.is_available(): + import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) + + logger = InvokeAILogger().getLogger(config=config) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 65a8734690..9084a7bf48 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -28,6 +28,8 @@ from pydantic.fields import Undefined, ModelField from pydantic.typing import NoArgAnyCallable import semver +from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig + if TYPE_CHECKING: from ..services.invocation_services import InvocationServices @@ -470,6 +472,7 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_all_subclasses(cls): + app_config = InvokeAIAppConfig.get_config() subclasses = [] toprocess = [cls] while len(toprocess) > 0: @@ -477,7 +480,23 @@ class BaseInvocation(ABC, BaseModel): next_subclasses = next.__subclasses__() subclasses.extend(next_subclasses) toprocess.extend(next_subclasses) - return subclasses + allowed_invocations = [] + for sc in subclasses: + is_in_allowlist = ( + sc.__fields__.get("type").default in app_config.allow_nodes + if isinstance(app_config.allow_nodes, list) + else True + ) + + is_in_denylist = ( + sc.__fields__.get("type").default in app_config.deny_nodes + if isinstance(app_config.deny_nodes, list) + else False + ) + + if is_in_allowlist and not is_in_denylist: + allowed_invocations.append(sc) + return allowed_invocations @classmethod def get_invocations(cls): diff --git a/invokeai/app/services/config/base.py b/invokeai/app/services/config/base.py index b83621c708..33fd87b03a 100644 --- a/invokeai/app/services/config/base.py +++ b/invokeai/app/services/config/base.py @@ -42,7 +42,9 @@ class InvokeAISettings(BaseSettings): def parse_args(self, argv: list = sys.argv[1:]): parser = self.get_parser() - opt = parser.parse_args(argv) + opt, unknown_opts = parser.parse_known_args(argv) + if len(unknown_opts) > 0: + print("Unknown args:", unknown_opts) for name in self.__fields__: if name not in self._excluded(): value = getattr(opt, name) diff --git a/invokeai/app/services/config/invokeai_config.py b/invokeai/app/services/config/invokeai_config.py index 1a2c22c89a..7b687a28a1 100644 --- a/invokeai/app/services/config/invokeai_config.py +++ b/invokeai/app/services/config/invokeai_config.py @@ -254,6 +254,10 @@ class InvokeAIAppConfig(InvokeAISettings): attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", ) force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) + # NODES + allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes") + deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes") + # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance') diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts index dd86c77735..5599913a18 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts @@ -9,13 +9,17 @@ import { startAppListening } from '..'; export const addReceivedOpenAPISchemaListener = () => { startAppListening({ actionCreator: receivedOpenAPISchema.fulfilled, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const log = logger('system'); const schemaJSON = action.payload; log.debug({ schemaJSON }, 'Received OpenAPI schema'); - - const nodeTemplates = parseSchema(schemaJSON); + const { nodesAllowlist, nodesDenylist } = getState().config; + const nodeTemplates = parseSchema( + schemaJSON, + nodesAllowlist, + nodesDenylist + ); log.debug( { nodeTemplates: parseify(nodeTemplates) }, diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 8b247327c7..a02c16cf7a 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -50,6 +50,8 @@ export type AppConfig = { disabledFeatures: AppFeature[]; disabledSDFeatures: SDFeature[]; canRestoreDeletedImagesFromBin: boolean; + nodesAllowlist: string[] | undefined; + nodesDenylist: string[] | undefined; sd: { defaultModel?: string; disabledControlNetModels: string[]; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 553d0770aa..8615a12c46 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -60,11 +60,23 @@ const isNotInDenylist = (schema: InvocationSchemaObject) => !invocationDenylist.includes(schema.properties.type.default); export const parseSchema = ( - openAPI: OpenAPIV3.Document + openAPI: OpenAPIV3.Document, + nodesAllowlistExtra: string[] | undefined = undefined, + nodesDenylistExtra: string[] | undefined = undefined ): Record => { const filteredSchemas = Object.values(openAPI.components?.schemas ?? {}) .filter(isInvocationSchemaObject) - .filter(isNotInDenylist); + .filter(isNotInDenylist) + .filter((schema) => + nodesAllowlistExtra + ? nodesAllowlistExtra.includes(schema.properties.type.default) + : true + ) + .filter((schema) => + nodesDenylistExtra + ? !nodesDenylistExtra.includes(schema.properties.type.default) + : true + ); const invocations = filteredSchemas.reduce< Record diff --git a/invokeai/frontend/web/src/features/system/store/configSlice.ts b/invokeai/frontend/web/src/features/system/store/configSlice.ts index 8ad0ee33e4..36a61be969 100644 --- a/invokeai/frontend/web/src/features/system/store/configSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/configSlice.ts @@ -15,6 +15,8 @@ export const initialConfigState: AppConfig = { 'perlinNoise', 'noiseThreshold', ], + nodesAllowlist: undefined, + nodesDenylist: undefined, canRestoreDeletedImagesFromBin: true, sd: { disabledControlNetModels: [], diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index bf8036ba98..022762bc78 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -3,7 +3,7 @@ import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { InvokeLogLevel } from 'app/logging/logger'; import { userInvoked } from 'app/store/actions'; import { t } from 'i18next'; -import { get, startCase, upperFirst } from 'lodash-es'; +import { get, startCase, truncate, upperFirst } from 'lodash-es'; import { LogLevelName } from 'roarr'; import { isAnySessionRejected, @@ -357,10 +357,13 @@ export const systemSlice = createSlice({ result.data.error.detail.map((e) => { state.toastQueue.push( makeToast({ - title: upperFirst(e.msg), + title: truncate(upperFirst(e.msg), { length: 128 }), status: 'error', - description: `Path: - ${e.loc.slice(3).join('.')}`, + description: truncate( + `Path: + ${e.loc.join('.')}`, + { length: 128 } + ), duration, }) ); @@ -375,7 +378,10 @@ export const systemSlice = createSlice({ makeToast({ title: t('toast.serverError'), status: 'error', - description: get(errorDescription, 'detail', 'Unknown Error'), + description: truncate( + get(errorDescription, 'detail', 'Unknown Error'), + { length: 128 } + ), duration, }) ); diff --git a/tests/test_config.py b/tests/test_config.py index 88da7a02ab..3c1646d860 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,10 @@ import os +from pathlib import Path from typing import Any import pytest from omegaconf import OmegaConf -from pathlib import Path +from pydantic import ValidationError from invokeai.app.services.config import InvokeAIAppConfig @@ -147,3 +148,58 @@ def test_type_coercion(patch_rootdir): conf.parse_args(argv=["--root=/tmp/foobar"]) assert conf.root == Path("/tmp/different") assert isinstance(conf.root, Path) + + +@pytest.mark.xfail( + reason=""" + This test fails when run as part of the full test suite. + + This test needs to deny nodes from being included in the InvocationsUnion by providing + an app configuration as a test fixture. Pytest executes all test files before running + tests, so the app configuration is already initialized by the time this test runs, and + the InvocationUnion is already created and the denied nodes are not omitted from it. + + This test passes when `test_config.py` is tested in isolation. + + Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in + other test files? + """ +) +def test_deny_nodes(patch_rootdir): + # Allow integer, string and float, but explicitly deny float + allow_deny_nodes_conf = OmegaConf.create( + """ + InvokeAI: + Nodes: + allow_nodes: + - integer + - string + - float + deny_nodes: + - float + """ + ) + # must parse config before importing Graph, so its nodes union uses the config + conf = InvokeAIAppConfig().get_config() + conf.parse_args(conf=allow_deny_nodes_conf, argv=[]) + from invokeai.app.services.graph import Graph + + # confirm graph validation fails when using denied node + Graph(nodes={"1": {"id": "1", "type": "integer"}}) + Graph(nodes={"1": {"id": "1", "type": "string"}}) + + with pytest.raises(ValidationError): + Graph(nodes={"1": {"id": "1", "type": "float"}}) + + from invokeai.app.invocations.baseinvocation import BaseInvocation + + # confirm invocations union will not have denied nodes + all_invocations = BaseInvocation.get_invocations() + + has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1 + has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1 + has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1 + + assert has_integer + assert has_string + assert not has_float