Merge branch 'main' into main

This commit is contained in:
Millun Atluri 2023-09-09 12:31:26 +10:00 committed by GitHub
commit abc50ce88b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 214 additions and 108 deletions

View File

@ -1,45 +1,47 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
import logging
import socket
from inspect import signature
from pathlib import Path
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
import invokeai.frontend.web as web_dir
import mimetypes
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio
import logging
import mimetypes
import socket
from inspect import signature
from pathlib import Path
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
import torch
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.version.invokeai_version import __version__
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, sessions
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
if torch.backends.mps.is_available():
if torch.backends.mps.is_available():
# noinspection PyUnresolvedReferences
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
# fix for windows mimetypes registry entries being borked

View File

@ -1,67 +1,64 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse
import re
import shlex
import sys
import time
from typing import Union, get_type_hints, Optional
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
config = InvokeAIAppConfig.get_config()
config.parse_args()
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.invocation_stats import InvocationStatsService
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
if True: # hack to make flake8 happy with imports coming after setting up the config
import argparse
import re
import shlex
import sys
import time
from typing import Optional, Union, get_type_hints
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.graph import (
import torch
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
from .services.events import EventServiceBase
from .services.graph import (
Edge,
EdgeConnection,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
are_connection_types_compatible,
)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
import torch
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available():
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)

View File

@ -28,6 +28,8 @@ from pydantic.fields import Undefined, ModelField
from pydantic.typing import NoArgAnyCallable
import semver
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
@ -470,6 +472,7 @@ class BaseInvocation(ABC, BaseModel):
@classmethod
def get_all_subclasses(cls):
app_config = InvokeAIAppConfig.get_config()
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
@ -477,7 +480,23 @@ class BaseInvocation(ABC, BaseModel):
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return subclasses
allowed_invocations = []
for sc in subclasses:
is_in_allowlist = (
sc.__fields__.get("type").default in app_config.allow_nodes
if isinstance(app_config.allow_nodes, list)
else True
)
is_in_denylist = (
sc.__fields__.get("type").default in app_config.deny_nodes
if isinstance(app_config.deny_nodes, list)
else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.append(sc)
return allowed_invocations
@classmethod
def get_invocations(cls):

View File

@ -42,7 +42,9 @@ class InvokeAISettings(BaseSettings):
def parse_args(self, argv: list = sys.argv[1:]):
parser = self.get_parser()
opt = parser.parse_args(argv)
opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0:
print("Unknown args:", unknown_opts)
for name in self.__fields__:
if name not in self._excluded():
value = getattr(opt, name)

View File

@ -254,6 +254,10 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')

View File

@ -9,13 +9,17 @@ import { startAppListening } from '..';
export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.fulfilled,
effect: (action, { dispatch }) => {
effect: (action, { dispatch, getState }) => {
const log = logger('system');
const schemaJSON = action.payload;
log.debug({ schemaJSON }, 'Received OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(
schemaJSON,
nodesAllowlist,
nodesDenylist
);
log.debug(
{ nodeTemplates: parseify(nodeTemplates) },

View File

@ -50,6 +50,8 @@ export type AppConfig = {
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
canRestoreDeletedImagesFromBin: boolean;
nodesAllowlist: string[] | undefined;
nodesDenylist: string[] | undefined;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];

View File

@ -60,11 +60,23 @@ const isNotInDenylist = (schema: InvocationSchemaObject) =>
!invocationDenylist.includes(schema.properties.type.default);
export const parseSchema = (
openAPI: OpenAPIV3.Document
openAPI: OpenAPIV3.Document,
nodesAllowlistExtra: string[] | undefined = undefined,
nodesDenylistExtra: string[] | undefined = undefined
): Record<string, InvocationTemplate> => {
const filteredSchemas = Object.values(openAPI.components?.schemas ?? {})
.filter(isInvocationSchemaObject)
.filter(isNotInDenylist);
.filter(isNotInDenylist)
.filter((schema) =>
nodesAllowlistExtra
? nodesAllowlistExtra.includes(schema.properties.type.default)
: true
)
.filter((schema) =>
nodesDenylistExtra
? !nodesDenylistExtra.includes(schema.properties.type.default)
: true
);
const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate>

View File

@ -15,6 +15,8 @@ export const initialConfigState: AppConfig = {
'perlinNoise',
'noiseThreshold',
],
nodesAllowlist: undefined,
nodesDenylist: undefined,
canRestoreDeletedImagesFromBin: true,
sd: {
disabledControlNetModels: [],

View File

@ -3,7 +3,7 @@ import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { t } from 'i18next';
import { get, startCase, upperFirst } from 'lodash-es';
import { get, startCase, truncate, upperFirst } from 'lodash-es';
import { LogLevelName } from 'roarr';
import {
isAnySessionRejected,
@ -357,10 +357,13 @@ export const systemSlice = createSlice({
result.data.error.detail.map((e) => {
state.toastQueue.push(
makeToast({
title: upperFirst(e.msg),
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description: `Path:
${e.loc.slice(3).join('.')}`,
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
duration,
})
);
@ -375,7 +378,10 @@ export const systemSlice = createSlice({
makeToast({
title: t('toast.serverError'),
status: 'error',
description: get(errorDescription, 'detail', 'Unknown Error'),
description: truncate(
get(errorDescription, 'detail', 'Unknown Error'),
{ length: 128 }
),
duration,
})
);

View File

@ -1,9 +1,10 @@
import os
from pathlib import Path
from typing import Any
import pytest
from omegaconf import OmegaConf
from pathlib import Path
from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
@ -147,3 +148,58 @@ def test_type_coercion(patch_rootdir):
conf.parse_args(argv=["--root=/tmp/foobar"])
assert conf.root == Path("/tmp/different")
assert isinstance(conf.root, Path)
@pytest.mark.xfail(
reason="""
This test fails when run as part of the full test suite.
This test needs to deny nodes from being included in the InvocationsUnion by providing
an app configuration as a test fixture. Pytest executes all test files before running
tests, so the app configuration is already initialized by the time this test runs, and
the InvocationUnion is already created and the denied nodes are not omitted from it.
This test passes when `test_config.py` is tested in isolation.
Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in
other test files?
"""
)
def test_deny_nodes(patch_rootdir):
# Allow integer, string and float, but explicitly deny float
allow_deny_nodes_conf = OmegaConf.create(
"""
InvokeAI:
Nodes:
allow_nodes:
- integer
- string
- float
deny_nodes:
- float
"""
)
# must parse config before importing Graph, so its nodes union uses the config
conf = InvokeAIAppConfig().get_config()
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.graph import Graph
# confirm graph validation fails when using denied node
Graph(nodes={"1": {"id": "1", "type": "integer"}})
Graph(nodes={"1": {"id": "1", "type": "string"}})
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
from invokeai.app.invocations.baseinvocation import BaseInvocation
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
assert has_integer
assert has_string
assert not has_float