Merge branch 'main' into main

This commit is contained in:
Millun Atluri 2023-09-09 12:31:26 +10:00 committed by GitHub
commit abc50ce88b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 214 additions and 108 deletions

View File

@ -1,10 +1,21 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio import asyncio
import logging import logging
import mimetypes
import socket import socket
from inspect import signature from inspect import signature
from pathlib import Path from pathlib import Path
import torch
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -15,31 +26,22 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema from pydantic.schema import schema
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
import invokeai.frontend.web as web_dir
import mimetypes
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.version.invokeai_version import __version__
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, sessions
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config) logger = InvokeAILogger.getLogger(config=app_config)
# fix for windows mimetypes registry entries being borked # fix for windows mimetypes registry entries being borked

View File

@ -1,41 +1,42 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
config = InvokeAIAppConfig.get_config()
config.parse_args()
if True: # hack to make flake8 happy with imports coming after setting up the config
import argparse import argparse
import re import re
import shlex import shlex
import sys import sys
import time import time
from typing import Union, get_type_hints, Optional from typing import Optional, Union, get_type_hints
import torch
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from .services.config import InvokeAIAppConfig from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
from invokeai.backend.util.logging import InvokeAILogger from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
from invokeai.version.invokeai_version import __version__
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.backend.util.logging import InvokeAILogger
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs from invokeai.version.invokeai_version import __version__
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.graph import ( from .services.graph import (
Edge, Edge,
@ -49,19 +50,15 @@ from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
from .services.invoker import Invoker from .services.invoker import Invoker
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .services.model_manager_service import ModelManagerService from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
import torch
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config) logger = InvokeAILogger().getLogger(config=config)

View File

@ -28,6 +28,8 @@ from pydantic.fields import Undefined, ModelField
from pydantic.typing import NoArgAnyCallable from pydantic.typing import NoArgAnyCallable
import semver import semver
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
@ -470,6 +472,7 @@ class BaseInvocation(ABC, BaseModel):
@classmethod @classmethod
def get_all_subclasses(cls): def get_all_subclasses(cls):
app_config = InvokeAIAppConfig.get_config()
subclasses = [] subclasses = []
toprocess = [cls] toprocess = [cls]
while len(toprocess) > 0: while len(toprocess) > 0:
@ -477,7 +480,23 @@ class BaseInvocation(ABC, BaseModel):
next_subclasses = next.__subclasses__() next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses) subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses) toprocess.extend(next_subclasses)
return subclasses allowed_invocations = []
for sc in subclasses:
is_in_allowlist = (
sc.__fields__.get("type").default in app_config.allow_nodes
if isinstance(app_config.allow_nodes, list)
else True
)
is_in_denylist = (
sc.__fields__.get("type").default in app_config.deny_nodes
if isinstance(app_config.deny_nodes, list)
else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.append(sc)
return allowed_invocations
@classmethod @classmethod
def get_invocations(cls): def get_invocations(cls):

View File

@ -42,7 +42,9 @@ class InvokeAISettings(BaseSettings):
def parse_args(self, argv: list = sys.argv[1:]): def parse_args(self, argv: list = sys.argv[1:]):
parser = self.get_parser() parser = self.get_parser()
opt = parser.parse_args(argv) opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0:
print("Unknown args:", unknown_opts)
for name in self.__fields__: for name in self.__fields__:
if name not in self._excluded(): if name not in self._excluded():
value = getattr(opt, name) value = getattr(opt, name)

View File

@ -254,6 +254,10 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", ) attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance') free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')

View File

@ -9,13 +9,17 @@ import { startAppListening } from '..';
export const addReceivedOpenAPISchemaListener = () => { export const addReceivedOpenAPISchemaListener = () => {
startAppListening({ startAppListening({
actionCreator: receivedOpenAPISchema.fulfilled, actionCreator: receivedOpenAPISchema.fulfilled,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const log = logger('system'); const log = logger('system');
const schemaJSON = action.payload; const schemaJSON = action.payload;
log.debug({ schemaJSON }, 'Received OpenAPI schema'); log.debug({ schemaJSON }, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(schemaJSON); const nodeTemplates = parseSchema(
schemaJSON,
nodesAllowlist,
nodesDenylist
);
log.debug( log.debug(
{ nodeTemplates: parseify(nodeTemplates) }, { nodeTemplates: parseify(nodeTemplates) },

View File

@ -50,6 +50,8 @@ export type AppConfig = {
disabledFeatures: AppFeature[]; disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[]; disabledSDFeatures: SDFeature[];
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
nodesAllowlist: string[] | undefined;
nodesDenylist: string[] | undefined;
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: string[]; disabledControlNetModels: string[];

View File

@ -60,11 +60,23 @@ const isNotInDenylist = (schema: InvocationSchemaObject) =>
!invocationDenylist.includes(schema.properties.type.default); !invocationDenylist.includes(schema.properties.type.default);
export const parseSchema = ( export const parseSchema = (
openAPI: OpenAPIV3.Document openAPI: OpenAPIV3.Document,
nodesAllowlistExtra: string[] | undefined = undefined,
nodesDenylistExtra: string[] | undefined = undefined
): Record<string, InvocationTemplate> => { ): Record<string, InvocationTemplate> => {
const filteredSchemas = Object.values(openAPI.components?.schemas ?? {}) const filteredSchemas = Object.values(openAPI.components?.schemas ?? {})
.filter(isInvocationSchemaObject) .filter(isInvocationSchemaObject)
.filter(isNotInDenylist); .filter(isNotInDenylist)
.filter((schema) =>
nodesAllowlistExtra
? nodesAllowlistExtra.includes(schema.properties.type.default)
: true
)
.filter((schema) =>
nodesDenylistExtra
? !nodesDenylistExtra.includes(schema.properties.type.default)
: true
);
const invocations = filteredSchemas.reduce< const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate> Record<string, InvocationTemplate>

View File

@ -15,6 +15,8 @@ export const initialConfigState: AppConfig = {
'perlinNoise', 'perlinNoise',
'noiseThreshold', 'noiseThreshold',
], ],
nodesAllowlist: undefined,
nodesDenylist: undefined,
canRestoreDeletedImagesFromBin: true, canRestoreDeletedImagesFromBin: true,
sd: { sd: {
disabledControlNetModels: [], disabledControlNetModels: [],

View File

@ -3,7 +3,7 @@ import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { InvokeLogLevel } from 'app/logging/logger'; import { InvokeLogLevel } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { t } from 'i18next'; import { t } from 'i18next';
import { get, startCase, upperFirst } from 'lodash-es'; import { get, startCase, truncate, upperFirst } from 'lodash-es';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { import {
isAnySessionRejected, isAnySessionRejected,
@ -357,10 +357,13 @@ export const systemSlice = createSlice({
result.data.error.detail.map((e) => { result.data.error.detail.map((e) => {
state.toastQueue.push( state.toastQueue.push(
makeToast({ makeToast({
title: upperFirst(e.msg), title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error', status: 'error',
description: `Path: description: truncate(
${e.loc.slice(3).join('.')}`, `Path:
${e.loc.join('.')}`,
{ length: 128 }
),
duration, duration,
}) })
); );
@ -375,7 +378,10 @@ export const systemSlice = createSlice({
makeToast({ makeToast({
title: t('toast.serverError'), title: t('toast.serverError'),
status: 'error', status: 'error',
description: get(errorDescription, 'detail', 'Unknown Error'), description: truncate(
get(errorDescription, 'detail', 'Unknown Error'),
{ length: 128 }
),
duration, duration,
}) })
); );

View File

@ -1,9 +1,10 @@
import os import os
from pathlib import Path
from typing import Any from typing import Any
import pytest import pytest
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pathlib import Path from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
@ -147,3 +148,58 @@ def test_type_coercion(patch_rootdir):
conf.parse_args(argv=["--root=/tmp/foobar"]) conf.parse_args(argv=["--root=/tmp/foobar"])
assert conf.root == Path("/tmp/different") assert conf.root == Path("/tmp/different")
assert isinstance(conf.root, Path) assert isinstance(conf.root, Path)
@pytest.mark.xfail(
reason="""
This test fails when run as part of the full test suite.
This test needs to deny nodes from being included in the InvocationsUnion by providing
an app configuration as a test fixture. Pytest executes all test files before running
tests, so the app configuration is already initialized by the time this test runs, and
the InvocationUnion is already created and the denied nodes are not omitted from it.
This test passes when `test_config.py` is tested in isolation.
Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in
other test files?
"""
)
def test_deny_nodes(patch_rootdir):
# Allow integer, string and float, but explicitly deny float
allow_deny_nodes_conf = OmegaConf.create(
"""
InvokeAI:
Nodes:
allow_nodes:
- integer
- string
- float
deny_nodes:
- float
"""
)
# must parse config before importing Graph, so its nodes union uses the config
conf = InvokeAIAppConfig().get_config()
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.graph import Graph
# confirm graph validation fails when using denied node
Graph(nodes={"1": {"id": "1", "type": "integer"}})
Graph(nodes={"1": {"id": "1", "type": "string"}})
with pytest.raises(ValidationError):
Graph(nodes={"1": {"id": "1", "type": "float"}})
from invokeai.app.invocations.baseinvocation import BaseInvocation
# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
has_integer = len([i for i in all_invocations if i.__fields__.get("type").default == "integer"]) == 1
has_string = len([i for i in all_invocations if i.__fields__.get("type").default == "string"]) == 1
has_float = len([i for i in all_invocations if i.__fields__.get("type").default == "float"]) == 1
assert has_integer
assert has_string
assert not has_float