mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into main
This commit is contained in:
commit
abc50ce88b
@ -1,10 +1,21 @@
|
|||||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
|
# values from the command line or config file.
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
app_config.parse_args()
|
||||||
|
|
||||||
|
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import mimetypes
|
||||||
import socket
|
import socket
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@ -15,31 +26,22 @@ from fastapi_events.handlers.local import local_handler
|
|||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
from pydantic.schema import schema
|
from pydantic.schema import schema
|
||||||
|
|
||||||
from .services.config import InvokeAIAppConfig
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
|
||||||
|
|
||||||
import invokeai.frontend.web as web_dir
|
|
||||||
import mimetypes
|
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
|
||||||
from .api.sockets import SocketIO
|
|
||||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
import invokeai.frontend.web as web_dir
|
||||||
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
|
from ..backend.util.logging import InvokeAILogger
|
||||||
|
from .api.dependencies import ApiDependencies
|
||||||
|
from .api.routers import app_info, board_images, boards, images, models, sessions
|
||||||
|
from .api.sockets import SocketIO
|
||||||
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
app_config.parse_args()
|
|
||||||
logger = InvokeAILogger.getLogger(config=app_config)
|
logger = InvokeAILogger.getLogger(config=app_config)
|
||||||
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
# fix for windows mimetypes registry entries being borked
|
||||||
|
@ -1,41 +1,42 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
|
# values from the command line or config file.
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
|
||||||
|
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||||
import argparse
|
import argparse
|
||||||
import re
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Union, get_type_hints, Optional
|
from typing import Optional, Union, get_type_hints
|
||||||
|
|
||||||
|
import torch
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
# This should come early so that the logger can pick up its configuration options
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
from .services.config import InvokeAIAppConfig
|
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
||||||
from invokeai.version.invokeai_version import __version__
|
|
||||||
|
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
|
||||||
SqliteBoardImageRecordStorage,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.board_images import (
|
|
||||||
BoardImagesService,
|
|
||||||
BoardImagesServiceDependencies,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from invokeai.version.invokeai_version import __version__
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
|
||||||
|
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
||||||
from .cli.completer import set_autocompleter
|
from .cli.completer import set_autocompleter
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
|
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.graph import (
|
from .services.graph import (
|
||||||
Edge,
|
Edge,
|
||||||
@ -49,19 +50,15 @@ from .services.image_file_storage import DiskImageFileStorage
|
|||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from .services.model_manager_service import ModelManagerService
|
from .services.model_manager_service import ModelManagerService
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
|
||||||
import torch
|
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
config.parse_args()
|
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,6 +28,8 @@ from pydantic.fields import Undefined, ModelField
|
|||||||
from pydantic.typing import NoArgAnyCallable
|
from pydantic.typing import NoArgAnyCallable
|
||||||
import semver
|
import semver
|
||||||
|
|
||||||
|
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
|
|
||||||
@ -470,6 +472,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses(cls):
|
def get_all_subclasses(cls):
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
subclasses = []
|
subclasses = []
|
||||||
toprocess = [cls]
|
toprocess = [cls]
|
||||||
while len(toprocess) > 0:
|
while len(toprocess) > 0:
|
||||||
@ -477,7 +480,23 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
next_subclasses = next.__subclasses__()
|
next_subclasses = next.__subclasses__()
|
||||||
subclasses.extend(next_subclasses)
|
subclasses.extend(next_subclasses)
|
||||||
toprocess.extend(next_subclasses)
|
toprocess.extend(next_subclasses)
|
||||||
return subclasses
|
allowed_invocations = []
|
||||||
|
for sc in subclasses:
|
||||||
|
is_in_allowlist = (
|
||||||
|
sc.__fields__.get("type").default in app_config.allow_nodes
|
||||||
|
if isinstance(app_config.allow_nodes, list)
|
||||||
|
else True
|
||||||
|
)
|
||||||
|
|
||||||
|
is_in_denylist = (
|
||||||
|
sc.__fields__.get("type").default in app_config.deny_nodes
|
||||||
|
if isinstance(app_config.deny_nodes, list)
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_in_allowlist and not is_in_denylist:
|
||||||
|
allowed_invocations.append(sc)
|
||||||
|
return allowed_invocations
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invocations(cls):
|
def get_invocations(cls):
|
||||||
|
@ -42,7 +42,9 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
def parse_args(self, argv: list = sys.argv[1:]):
|
def parse_args(self, argv: list = sys.argv[1:]):
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt = parser.parse_args(argv)
|
opt, unknown_opts = parser.parse_known_args(argv)
|
||||||
|
if len(unknown_opts) > 0:
|
||||||
|
print("Unknown args:", unknown_opts)
|
||||||
for name in self.__fields__:
|
for name in self.__fields__:
|
||||||
if name not in self._excluded():
|
if name not in self._excluded():
|
||||||
value = getattr(opt, name)
|
value = getattr(opt, name)
|
||||||
|
@ -254,6 +254,10 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||||
|
|
||||||
|
# NODES
|
||||||
|
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
||||||
|
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
||||||
|
|
||||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
|
@ -9,13 +9,17 @@ import { startAppListening } from '..';
|
|||||||
export const addReceivedOpenAPISchemaListener = () => {
|
export const addReceivedOpenAPISchemaListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: receivedOpenAPISchema.fulfilled,
|
actionCreator: receivedOpenAPISchema.fulfilled,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const log = logger('system');
|
const log = logger('system');
|
||||||
const schemaJSON = action.payload;
|
const schemaJSON = action.payload;
|
||||||
|
|
||||||
log.debug({ schemaJSON }, 'Received OpenAPI schema');
|
log.debug({ schemaJSON }, 'Received OpenAPI schema');
|
||||||
|
const { nodesAllowlist, nodesDenylist } = getState().config;
|
||||||
const nodeTemplates = parseSchema(schemaJSON);
|
const nodeTemplates = parseSchema(
|
||||||
|
schemaJSON,
|
||||||
|
nodesAllowlist,
|
||||||
|
nodesDenylist
|
||||||
|
);
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
{ nodeTemplates: parseify(nodeTemplates) },
|
{ nodeTemplates: parseify(nodeTemplates) },
|
||||||
|
@ -50,6 +50,8 @@ export type AppConfig = {
|
|||||||
disabledFeatures: AppFeature[];
|
disabledFeatures: AppFeature[];
|
||||||
disabledSDFeatures: SDFeature[];
|
disabledSDFeatures: SDFeature[];
|
||||||
canRestoreDeletedImagesFromBin: boolean;
|
canRestoreDeletedImagesFromBin: boolean;
|
||||||
|
nodesAllowlist: string[] | undefined;
|
||||||
|
nodesDenylist: string[] | undefined;
|
||||||
sd: {
|
sd: {
|
||||||
defaultModel?: string;
|
defaultModel?: string;
|
||||||
disabledControlNetModels: string[];
|
disabledControlNetModels: string[];
|
||||||
|
@ -60,11 +60,23 @@ const isNotInDenylist = (schema: InvocationSchemaObject) =>
|
|||||||
!invocationDenylist.includes(schema.properties.type.default);
|
!invocationDenylist.includes(schema.properties.type.default);
|
||||||
|
|
||||||
export const parseSchema = (
|
export const parseSchema = (
|
||||||
openAPI: OpenAPIV3.Document
|
openAPI: OpenAPIV3.Document,
|
||||||
|
nodesAllowlistExtra: string[] | undefined = undefined,
|
||||||
|
nodesDenylistExtra: string[] | undefined = undefined
|
||||||
): Record<string, InvocationTemplate> => {
|
): Record<string, InvocationTemplate> => {
|
||||||
const filteredSchemas = Object.values(openAPI.components?.schemas ?? {})
|
const filteredSchemas = Object.values(openAPI.components?.schemas ?? {})
|
||||||
.filter(isInvocationSchemaObject)
|
.filter(isInvocationSchemaObject)
|
||||||
.filter(isNotInDenylist);
|
.filter(isNotInDenylist)
|
||||||
|
.filter((schema) =>
|
||||||
|
nodesAllowlistExtra
|
||||||
|
? nodesAllowlistExtra.includes(schema.properties.type.default)
|
||||||
|
: true
|
||||||
|
)
|
||||||
|
.filter((schema) =>
|
||||||
|
nodesDenylistExtra
|
||||||
|
? !nodesDenylistExtra.includes(schema.properties.type.default)
|
||||||
|
: true
|
||||||
|
);
|
||||||
|
|
||||||
const invocations = filteredSchemas.reduce<
|
const invocations = filteredSchemas.reduce<
|
||||||
Record<string, InvocationTemplate>
|
Record<string, InvocationTemplate>
|
||||||
|
@ -15,6 +15,8 @@ export const initialConfigState: AppConfig = {
|
|||||||
'perlinNoise',
|
'perlinNoise',
|
||||||
'noiseThreshold',
|
'noiseThreshold',
|
||||||
],
|
],
|
||||||
|
nodesAllowlist: undefined,
|
||||||
|
nodesDenylist: undefined,
|
||||||
canRestoreDeletedImagesFromBin: true,
|
canRestoreDeletedImagesFromBin: true,
|
||||||
sd: {
|
sd: {
|
||||||
disabledControlNetModels: [],
|
disabledControlNetModels: [],
|
||||||
|
@ -3,7 +3,7 @@ import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
|||||||
import { InvokeLogLevel } from 'app/logging/logger';
|
import { InvokeLogLevel } from 'app/logging/logger';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { get, startCase, upperFirst } from 'lodash-es';
|
import { get, startCase, truncate, upperFirst } from 'lodash-es';
|
||||||
import { LogLevelName } from 'roarr';
|
import { LogLevelName } from 'roarr';
|
||||||
import {
|
import {
|
||||||
isAnySessionRejected,
|
isAnySessionRejected,
|
||||||
@ -357,10 +357,13 @@ export const systemSlice = createSlice({
|
|||||||
result.data.error.detail.map((e) => {
|
result.data.error.detail.map((e) => {
|
||||||
state.toastQueue.push(
|
state.toastQueue.push(
|
||||||
makeToast({
|
makeToast({
|
||||||
title: upperFirst(e.msg),
|
title: truncate(upperFirst(e.msg), { length: 128 }),
|
||||||
status: 'error',
|
status: 'error',
|
||||||
description: `Path:
|
description: truncate(
|
||||||
${e.loc.slice(3).join('.')}`,
|
`Path:
|
||||||
|
${e.loc.join('.')}`,
|
||||||
|
{ length: 128 }
|
||||||
|
),
|
||||||
duration,
|
duration,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
@ -375,7 +378,10 @@ export const systemSlice = createSlice({
|
|||||||
makeToast({
|
makeToast({
|
||||||
title: t('toast.serverError'),
|
title: t('toast.serverError'),
|
||||||
status: 'error',
|
status: 'error',
|
||||||
description: get(errorDescription, 'detail', 'Unknown Error'),
|
description: truncate(
|
||||||
|
get(errorDescription, 'detail', 'Unknown Error'),
|
||||||
|
{ length: 128 }
|
||||||
|
),
|
||||||
duration,
|
duration,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
@ -147,3 +148,58 @@ def test_type_coercion(patch_rootdir):
|
|||||||
conf.parse_args(argv=["--root=/tmp/foobar"])
|
conf.parse_args(argv=["--root=/tmp/foobar"])
|
||||||
assert conf.root == Path("/tmp/different")
|
assert conf.root == Path("/tmp/different")
|
||||||
assert isinstance(conf.root, Path)
|
assert isinstance(conf.root, Path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(
|
||||||
|
reason="""
|
||||||
|
This test fails when run as part of the full test suite.
|
||||||
|
|
||||||
|
This test needs to deny nodes from being included in the InvocationsUnion by providing
|
||||||
|
an app configuration as a test fixture. Pytest executes all test files before running
|
||||||
|
tests, so the app configuration is already initialized by the time this test runs, and
|
||||||
|
the InvocationUnion is already created and the denied nodes are not omitted from it.
|
||||||
|
|
||||||
|
This test passes when `test_config.py` is tested in isolation.
|
||||||
|
|
||||||
|
Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in
|
||||||
|
other test files?
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
def test_deny_nodes(patch_rootdir):
|
||||||
|
# Allow integer, string and float, but explicitly deny float
|
||||||
|
allow_deny_nodes_conf = OmegaConf.create(
|
||||||
|
"""
|
||||||
|
InvokeAI:
|
||||||
|
Nodes:
|
||||||
|
allow_nodes:
|
||||||
|
- integer
|
||||||
|
- string
|
||||||
|
- float
|
||||||
|
deny_nodes:
|
||||||
|
- float
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# must parse config before importing Graph, so its nodes union uses the config
|
||||||
|
conf = InvokeAIAppConfig().get_config()
|
||||||
|
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
|
|
||||||
|
# confirm graph validation fails when using denied node
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "integer"}})
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "string"}})
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Graph(nodes={"1": {"id": "1", "type": "float"}})
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
|
||||||
|
# confirm invocations union will not have denied nodes
|
||||||
|
all_invocations = BaseInvocation.get_invocations()
|
||||||
|
|
||||||
|
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
|
||||||
|
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
|
||||||
|
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
|
||||||
|
|
||||||
|
assert has_integer
|
||||||
|
assert has_string
|
||||||
|
assert not has_float
|
||||||
|
Loading…
Reference in New Issue
Block a user