Compare commits

...

25 Commits

Author SHA1 Message Date
8b281b792c implement changes requested by @kyle0654 for PR #3221
- Add argparse groups to command-line switches.
- Do not let the `link` and `link_node` arguments leak into invocation objects.
- Remove `type` field from `LibraryGraph`
2023-05-01 18:34:01 -04:00
447f590fa4 Merge branch 'main' into lstein/enhance/configuration 2023-04-30 23:03:08 -04:00
3f301d3db3 do not pass conf to invocations! This is a class variable 2023-04-28 16:51:42 -04:00
4e21f5f046 removed permission to add extra fields to pydantic objects 2023-04-28 16:34:37 -04:00
56462d62ae improve config support for API
- There is now only a single place where api_app.py parses command-line switches,
  the init file, and environment variables to create an InvokeAISettings object.
- This object is used to initialize uvicorn server (port, etc), and then
  stashed into the InvocationServices object for later use.
2023-04-28 16:16:22 -04:00
e67f2d7f6f disable test that won't run properly in CI environment 2023-04-28 10:15:35 -04:00
35ef37e481 Merge branch 'main' into lstein/enhance/configuration 2023-04-28 09:59:44 -04:00
cfb08d97f9 fix INVOKEAI_ROOT not being respected 2023-04-28 09:57:44 -04:00
c3a95bda2f Update invokeai/app/services/config.py
Co-authored-by: Eugene Brodsky <ebr@users.noreply.github.com>
2023-04-28 07:06:58 -04:00
a64bd95df2 Merge branch 'main' into lstein/enhance/configuration 2023-04-28 07:06:05 -04:00
575feb8aee Merge branch 'main' into lstein/enhance/configuration 2023-04-24 22:27:44 -04:00
8318d22d63 fix config test so it works on windows 2023-04-24 22:27:36 -04:00
ca04f08668 Update invokeai/app/services/config.py
Co-authored-by: Eugene Brodsky <ebr@users.noreply.github.com>
2023-04-24 21:38:24 -04:00
5908c75a7d populate Globals to accommodate legacy code 2023-04-23 22:23:31 -04:00
94fac09cb7 revert InvocationOutput to use BaseModel 2023-04-23 13:22:42 -04:00
49dd3d93e7 Merge branch 'main' into lstein/enhance/configuration 2023-04-22 20:19:36 +01:00
2c59164120 pytests working 2023-04-22 20:09:13 +01:00
f6fdd8b805 add unit tests 2023-04-22 19:33:10 +01:00
600258c24c configuration basically functional; graphs not picking up settings 2023-04-22 15:56:17 +01:00
57d032d7e6 add config_management to context; config web settings 2023-04-18 15:33:36 -04:00
196c21b7c9 turn off environment variable case sensitivity 2023-04-17 23:28:31 -04:00
95f6c85a29 merge with main 2023-04-17 22:24:28 -04:00
a0df350608 ok first draft 2023-04-17 22:18:04 -04:00
1815c15ae2 early version of configuration manager
- config manager controls the configuration of both the app as a whole
  (settings like --precision) and the settings for individual
  invocations (such as --strength).

- a yaml based configuration file sets the defaults for each invocation's fields

- command-line options are automatically parsed and supersede the config file
  settings:

  command-line-value > config-file-value > pydantic-default-value

Right now the manager isn't hooked into the CLI or API code, so it does
nothing.
2023-04-17 18:03:55 -04:00
629697c4dd support python 3.9 2023-04-16 10:35:29 -07:00
19 changed files with 594 additions and 153 deletions
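
Taken together, these commits replace the legacy Args object with a single pydantic-based settings object that layers its sources. A minimal sketch of the resulting precedence, assembled from the docstring and unit tests added in invokeai/app/services/config.py further down (the yaml values here are illustrative, not part of the diff):

from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig

init = OmegaConf.create('''
globals:
  nsfw_checker: False
  max_loaded_models: 5
''')

# a config-file value overrides the pydantic default ...
conf = InvokeAIAppConfig(conf=init, argv=[])
assert conf.max_loaded_models == 5

# ... a command-line switch overrides the config file ...
conf = InvokeAIAppConfig(conf=init, argv=['--max_loaded_models=10'])
assert conf.max_loaded_models == 10

# ... and an explicit initialization value overrides everything
conf = InvokeAIAppConfig(conf=init, argv=[], max_loaded_models=20)
assert conf.max_loaded_models == 20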

View File

@@ -9,7 +9,8 @@ from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ...backend.globals import Globals, copy_conf_to_globals
from ..services.config import InvokeAIWebConfig
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
@@ -45,11 +46,7 @@ class ApiDependencies:
@staticmethod
def initialize(config, event_handler_id: int):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
copy_conf_to_globals(config)
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
@@ -84,6 +81,7 @@ class ApiDependencies:
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
configuration=config,
)
create_system_graphs(services.graph_library)

View File

@@ -8,10 +8,6 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
from invokeai.backend.args import Args
models_router = APIRouter(prefix="/v1/models", tags=["models"])

View File

@@ -12,12 +12,11 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIWebConfig
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
@@ -33,30 +32,15 @@ app.add_middleware(
middleware_id=event_handler_id,
)
# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
socket_io = SocketIO(app)
config = {}
web_config = {}
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
config = Args()
config.parse_args()
ApiDependencies.initialize(
config=config, event_handler_id=event_handler_id
config=web_config, event_handler_id=event_handler_id
)
@@ -146,12 +130,21 @@ def overridden_redoc():
def invoke_api():
# parse command-line settings, environment and the init file
# (this is a module global)
global web_config
web_config = InvokeAIWebConfig()
app.add_middleware(
CORSMiddleware,
allow_origins=web_config.allow_origins,
allow_credentials=web_config.allow_credentials,
allow_methods=web_config.allow_methods,
allow_headers=web_config.allow_headers,
)
# Start our own event loop for eventing usage
# TODO: determine if there's a better way to do this
loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
config = uvicorn.Config(app=app, host=web_config.host, port=web_config.port, loop=loop)
# Use access_log to turn off logging
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())

View File

@@ -1,91 +1,52 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_type_hints
from pydantic import Field
import networkx as nx
import matplotlib.pyplot as plt
from ..invocations.baseinvocation import BaseInvocation
from ..services.config import InvokeAISettings
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
from ..services.invoker import Invoker
def add_field_argument(command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
help=field.field_info.description,
)
def add_parsers(
subparsers,
commands: list[type],
command_field: str = "type",
exclude_fields: list[str] = ["id", "type"],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
):
"""Adds parsers for each command to the subparsers"""
# Create subparsers for each command
for command in commands:
hints = get_type_hints(command)
cmd_name = get_args(hints[command_field])[0]
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
name = command.cmd_name()
command_parser = subparsers.add_parser(name, help=command.__doc__)
if add_arguments is not None:
add_arguments(command_parser)
# Convert all fields to arguments
fields = command.__fields__ # type: ignore
for name, field in fields.items():
if name in exclude_fields:
continue
add_field_argument(command_parser, name, field)
command.add_parser_arguments(command_parser)
def add_graph_parsers(
subparsers,
graphs: list[LibraryGraph],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
):
for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description)
if add_arguments is not None:
add_arguments(command_parser)
graph.add_parser_arguments(command_parser)
# Add arguments for inputs
for exposed_input in graph.exposed_inputs:
node = graph.graph.get_node(exposed_input.node_path)
field = node.__fields__[exposed_input.field]
default_override = getattr(node, exposed_input.field)
add_field_argument(command_parser, exposed_input.alias, field, default_override)
graph.add_field_argument(command_parser, exposed_input.alias, field, default_override)
class CliContext:
invoker: Invoker
@@ -130,7 +91,7 @@ class ExitCli(Exception):
pass
class BaseCommand(ABC, BaseModel):
class BaseCommand(ABC, InvokeAISettings):
"""A CLI command"""
# All commands must include a type name like this:

View File

@@ -10,9 +10,10 @@ import shlex
from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
from ...backend import ModelManager, Globals
from ...backend import ModelManager
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
from ..services.invocation_services import InvocationServices
# singleton object, class variable
completer = None
@@ -130,13 +131,13 @@ class Completer(object):
readline.redisplay()
self.linebuffer = None
def set_autocompleter(model_manager: ModelManager) -> Completer:
def set_autocompleter(services: InvocationServices) -> Completer:
global completer
if completer:
return completer
completer = Completer(model_manager)
completer = Completer(services.model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
@@ -152,7 +153,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
histfile = Path(Globals.root, ".invoke_history")
histfile = Path(services.configuration.root_dir / ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)

View File

@@ -19,7 +19,6 @@ from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
from .invocations import *
@@ -35,7 +34,8 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
from .services.config import InvokeAIAppConfig
from ..backend.globals import copy_conf_to_globals # temporary workaround for code depending on Globals
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@@ -190,24 +190,20 @@ def invoke_all(context: CliContext):
def invoke_cli():
config = Args()
config.parse_args()
config = InvokeAIAppConfig()
copy_conf_to_globals(config) # temporary workaround
model_manager = get_model_manager(config)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
completer = set_autocompleter(model_manager)
events = EventServiceBase()
output_folder = config.output_path
metadata = PngMetadataService()
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@@ -226,6 +222,7 @@ def invoke_cli():
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
@@ -242,6 +239,8 @@ def invoke_cli():
context = CliContext(invoker, session, parser)
set_autocompleter(services)
while True:
try:
cmd_input = input("invoke> ")
@@ -285,8 +284,17 @@ def invoke_cli():
command = CliCommand(command = invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
command = CliCommand(command=args)
if "id" in args:
args["id"] = args["id"] or current_id
# remove extraneous fields from initialization
exclude = ['link','link_node']
command_args = dict()
for key,value in args.items():
if key not in exclude:
command_args[key]=value
command = CliCommand(command=command_args)
if command is None:
continue

View File

@@ -4,10 +4,10 @@ from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from pydantic import BaseModel, Field
from pydantic import BaseModel, BaseSettings, Field
from ..services.invocation_services import InvocationServices
from ..services.config import InvokeAISettings
class InvocationContext:
services: InvocationServices
@@ -36,7 +36,7 @@ class BaseInvocationOutput(BaseModel):
return tuple(subclasses)
class BaseInvocation(ABC, BaseModel):
class BaseInvocation(ABC, InvokeAISettings):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
"""
@@ -101,8 +101,8 @@ class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
class InvocationConfig(BaseModel.Config):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
class InvocationConfig(BaseSettings.Config):
"""Customizes pydantic's BaseSettings.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs.

View File

@@ -0,0 +1,379 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein)
'''InvokeAI configuration system.
Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has top-level keys corresponding to an invocation name, a command, or
"globals" for global values such as `xformers_enabled`. Currently
graphs cannot be configured this way, but their constituents can be.
[file: invokeai.yaml]
globals:
nsfw_checker: False
max_loaded_models: 5
txt2img:
steps: 20
scheduler: k_heun
width: 768
img2img:
width: 1024
height: 1024
The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can use any OmegaConf dictionary by passing it
to the config object at initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig(conf=omegaconf)
By default, InvokeAIAppConfig will parse the contents of argv at
initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
conf = InvokeAIAppConfig(argv=['--xformers_enabled'])
It is also possible to set a value at initialization time. This value
has highest priority.
conf = InvokeAIAppConfig(xformers_enabled=True)
Any setting can be overridden by setting an environment variable of the
form "INVOKEAI_<command>_<setting>", as in:
export INVOKEAI_txt2img_steps=30
Order of precedence (from highest):
1) initialization options
2) command line options
3) environment variable options
4) config file options
5) pydantic defaults
Typical usage:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.invocations.generate import TextToImageInvocation
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig()
print(conf.nsfw_checker)
# get the text2image invocation and print its step value
text2image = TextToImageInvocation()
print(text2image.steps)
Computed properties:
The InvokeAIAppConfig object has a series of properties that
resolve paths relative to the runtime root directory. They each return
a Path object:
root_path - path to InvokeAI root
output_path - path to default outputs directory
model_conf_path - path to models.yaml
conf - alias for the above
embedding_path - path to the embeddings directory
lora_path - path to the LoRA directory
'''
import argparse
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import Any, ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml')
LEGACY_INIT_FILE = Path('invokeai.init')
class InvokeAISettings(BaseSettings):
'''
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
'''
initconf : ClassVar[DictConfig] = None
argparse_groups : ClassVar[Dict] = {}
def parse_args(self, argv: list=sys.argv[1:]):
parser = self.get_parser()
opt, _ = parser.parse_known_args(argv)
for name in self.__fields__:
if name not in self._excluded():
setattr(self, name, getattr(opt,name))
@classmethod
def add_parser_arguments(cls, parser):
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else 'INVOKEAI_'
if 'type' in get_type_hints(cls):
default_settings_stanza = get_args(get_type_hints(cls)['type'])[0]
else:
default_settings_stanza = 'globals'
initconf = cls.initconf.get(default_settings_stanza) if cls.initconf and default_settings_stanza in cls.initconf else None
fields = cls.__fields__
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
env_name = env_prefix+f'{cls.cmd_name()}_{name}'
if initconf and name in initconf:
field.default = initconf.get(name)
if env_name in os.environ:
field.default = os.environ[env_name]
cls.add_field_argument(parser, name, field)
@classmethod
def cmd_name(self, command_field: str='type')->str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
return 'globals'
@classmethod
def get_parser(cls)->ArgumentParser:
parser = ArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
)
cls.add_parser_arguments(parser)
return parser
@classmethod
def add_subparser(cls, parser: argparse.ArgumentParser):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self)->List[str]:
return ['type','initconf']
class Config:
env_file_encoding = 'utf-8'
arbitrary_types_allowed = True
env_prefix = 'INVOKEAI_'
case_sensitive = True
@classmethod
def customise_sources(
cls,
init_settings,
env_settings,
file_secret_settings,
):
return (
init_settings,
cls._omegaconf_settings_source,
env_settings,
file_secret_settings,
)
@classmethod
def _omegaconf_settings_source(cls, settings: BaseSettings) -> dict[str, Any]:
if initconf := InvokeAISettings.initconf:
return initconf.get(settings.cmd_name(),{})
else:
return {}
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category)
argparse_group = cls.argparse_groups[category]
else:
argparse_group = command_parser
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
else:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
def _find_root()->Path:
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif (
os.environ.get("VIRTUAL_ENV")
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
or
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
)
):
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
class InvokeAIAppConfig(InvokeAISettings):
'''
Application-wide settings.
'''
#fmt: off
type: Literal["globals"] = "globals"
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
infile : Path = Field(default=None, description='Path to a file of prompt commands to bulk generate from', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion embeddings directory', category='Paths')
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
#fmt: on
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
'''
Initialize InvokeAIAppConfig.
:param conf: alternate OmegaConf dictionary object
:param argv: alternate sys.argv list
:param **kwargs: attributes to initialize with
'''
super().__init__(**kwargs)
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
self.parse_args(argv)
if not conf:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except:
pass
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
self.parse_args(argv)
# restore initialization values
hints = get_type_hints(self)
for k in kwargs:
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
@property
def root_path(self)->Path:
'''
Path to the runtime root directory
'''
if self.root:
return self.root.expanduser()
else:
return self.find_root()
@property
def root_dir(self)->Path:
'''
Alias for above.
'''
return self.root_path
def _resolve(self,partial_path:Path)->Path:
return (self.root_path / partial_path).resolve()
@property
def output_path(self)->Path:
'''
Path to the default outputs directory.
'''
return self._resolve(self.outdir)
@property
def model_conf_path(self)->Path:
'''
Path to models configuration file.
'''
return self._resolve(self.conf_path)
@property
def conf(self)->Path:
'''
Path to models configuration file (alias for model_conf_path).
'''
return self.model_conf_path
@property
def embedding_path(self)->Path:
'''
Path to the textual inversion embeddings directory.
'''
return self._resolve(self.embedding_dir) if self.embedding_dir else None
@property
def lora_path(self)->Path:
'''
Path to the LoRA models directory.
'''
return self._resolve(self.lora_dir) if self.lora_dir else None
@property
def autoconvert_path(self)->Path:
'''
Path to the directory containing models to be imported automatically at startup.
'''
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
@property
def gfpgan_model_path(self)->Path:
'''
Path to the GFPGAN model.
'''
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
@staticmethod
def find_root()->Path:
'''
Choose the runtime root directory when not specified on command line or
init file.
'''
return _find_root()
class InvokeAIWebConfig(InvokeAIAppConfig):
'''
Web-specific settings
'''
#fmt: off
type : Literal["web"] = "web"
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
#fmt: on

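For quick reference, the computed-path properties defined above all resolve against the runtime root. A sketch (the INVOKEAI_ROOT value is illustrative; the variable must be set before the module is imported because the root default is computed at import time, as the unit tests below also do):

import os
os.environ['INVOKEAI_ROOT'] = '/tmp'  # must precede the import

from invokeai.app.services.config import InvokeAIAppConfig

conf = InvokeAIAppConfig(argv=[])
print(conf.root_path)        # the runtime root: /tmp
print(conf.output_path)      # <root>/outputs
print(conf.model_conf_path)  # <root>/configs/models.yaml
print(conf.lora_path)        # <root>/loras
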
View File

@@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from typing import Any, Dict, TypedDict, Union
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp

View File

@@ -3,7 +3,6 @@
import copy
import itertools
import uuid
from types import NoneType
from typing import (
Annotated,
Any,
@@ -14,9 +13,10 @@ from typing import (
get_origin,
get_type_hints,
)
NoneType = type(None)
import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, root_validator, validator, Extra
from pydantic.fields import Field
from ..invocations import *
@@ -25,6 +25,7 @@ from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationContext,
)
from .config import InvokeAISettings
class EdgeConnection(BaseModel):
@@ -211,9 +212,10 @@ class CollectInvocation(BaseInvocation):
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
class Graph(BaseModel):
class Graph(InvokeAISettings):
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
type: Literal["graph"] = "graph"
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
@@ -805,7 +807,7 @@ class GraphExecutionState(BaseModel):
]
}
def next(self) -> BaseInvocation | None:
def next(self) -> Union[BaseInvocation, None]:
"""Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
@@ -1154,7 +1156,7 @@ class ExposedNodeOutput(BaseModel):
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
class LibraryGraph(InvokeAISettings):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
@@ -1162,6 +1164,9 @@ class LibraryGraph(BaseModel):
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
class Config:
extra='allow'
@validator('exposed_inputs', 'exposed_outputs')
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):

View File

@@ -3,7 +3,7 @@
import time
from abc import ABC, abstractmethod
from queue import Queue
from typing import Union
from pydantic import BaseModel, Field
@@ -22,7 +22,7 @@ class InvocationQueueABC(ABC):
pass
@abstractmethod
def put(self, item: InvocationQueueItem | None) -> None:
def put(self, item: Union[InvocationQueueItem, None]) -> None:
pass
@abstractmethod
@@ -57,7 +57,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
return item
def put(self, item: InvocationQueueItem | None) -> None:
def put(self, item: Union[InvocationQueueItem, None]) -> None:
self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None:

View File

@@ -8,6 +8,7 @@ from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings
class InvocationServices:
"""Services that can be used by invocations"""
@@ -19,6 +20,7 @@ class InvocationServices:
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
@@ -37,6 +39,7 @@ class InvocationServices:
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
):
self.model_manager = model_manager
self.events = events
@@ -48,3 +51,4 @@
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration

View File

@@ -2,6 +2,7 @@
from abc import ABC
from threading import Event, Thread
from typing import Union
from ..invocations.baseinvocation import InvocationContext
from .graph import Graph, GraphExecutionState
@@ -21,7 +22,7 @@ class Invoker:
def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None:
) -> Union[str, None]:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
# Get the next invocation
@@ -44,7 +45,7 @@ class Invoker:
return invocation.id
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
def create_execution_state(self, graph: Union[Graph, None] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)

View File

@@ -4,7 +4,7 @@ import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict
from typing import Dict, Union
import torch
@@ -56,7 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
if name in self.__cache:
del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None:
def __get_cache(self, name: str) -> Union[torch.Tensor,None]:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):

View File

@@ -7,21 +7,20 @@ from omegaconf import OmegaConf
from pathlib import Path
import invokeai.version
from .config import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
def get_model_manager(config:InvokeAISettings) -> ModelManager:
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found.")
config, FileNotFoundError(f"The file {model_config} could not be found.")
)
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
print(f'>> InvokeAI runtime directory is "{config.root_dir}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@@ -31,20 +30,7 @@ def get_model_manager(config: Args) -> ModelManager:
import diffusers
diffusers.logging.set_verbosity_error()
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = config.embedding_path
else:
embedding_path = None
# migrate legacy models
ModelManager.migrate_models()
@@ -57,11 +43,11 @@
else choose_precision(device)
model_manager = ModelManager(
OmegaConf.load(config.conf),
OmegaConf.load(model_config),
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path),
embedding_path = embedding_path,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e)
@@ -71,12 +57,10 @@ def get_model_manager(config: Args) -> ModelManager:
# try to autoconvert new models
# autoimport new .ckpt files
if path := config.autoconvert:
model_manager.autoconvert_weights(
conf_path=config.conf,
weights_directory=path,
if config.autoconvert_path:
model_manager.heuristic_import(
config.autoconvert_path,
)
return model_manager
def report_model_error(opt: Namespace, e: Exception):

View File

@@ -16,6 +16,7 @@ import os.path as osp
from argparse import Namespace
from pathlib import Path
from typing import Union
from pydantic import BaseSettings
Globals = Namespace()
@@ -120,3 +121,15 @@ def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
return Path(home, subdir)
else:
return Path(Globals.root, "models", subdir)
def copy_conf_to_globals(conf: Union[dict,BaseSettings]):
'''
Given a dict or dict-like object, copy its keys and
values into the Globals Namespace. This is a transitional
workaround until we remove Globals entirely.
'''
if isinstance(conf,BaseSettings):
conf = conf.dict()
for key in conf.keys():
if key is not None:
setattr(Globals,key,conf[key])

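Both the CLI and API entry points in this PR route their parsed settings through this transitional helper so that untouched legacy code keeps working; roughly (a sketch, not part of the diff):

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.globals import Globals, copy_conf_to_globals

config = InvokeAIAppConfig(argv=[])
copy_conf_to_globals(config)   # each pydantic field becomes an attribute of the Globals namespace
print(Globals.root)            # legacy code that still reads Globals sees the new settings
print(Globals.always_use_cpu)
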
View File

@@ -8,7 +8,6 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest
@pytest.fixture
def simple_graph():
g = Graph()

View File

@@ -316,7 +316,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
def test_graph_iterator_invalid_if_input_not_list():
g = Graph()
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = IterateInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)

tests/test_config.py (new file, 99 lines)
View File

@@ -0,0 +1,99 @@
import os
import pytest
from omegaconf import OmegaConf
from pathlib import Path
os.environ['INVOKEAI_ROOT']='/tmp'
from invokeai.app.services.config import InvokeAIAppConfig, InvokeAISettings
from invokeai.app.invocations.generate import TextToImageInvocation
init1 = OmegaConf.create(
'''
globals:
nsfw_checker: False
max_loaded_models: 5
history:
count: 100
txt2img:
steps: 18
scheduler: k_heun
width: 768
img2img:
width: 1024
height: 1024
'''
)
init2 = OmegaConf.create(
'''
globals:
nsfw_checker: True
max_loaded_models: 2
history:
count: 10
'''
)
def test_use_init():
# note that we explicitly set omegaconf dict and argv here
# so that the values aren't read from ~/invokeai/invokeai.yaml and
# sys.argv respectively.
conf1 = InvokeAIAppConfig(init1,[])
assert conf1
assert conf1.max_loaded_models==5
assert not conf1.nsfw_checker
conf2 = InvokeAIAppConfig(init2,[])
assert conf2
assert conf2.nsfw_checker
assert conf2.max_loaded_models==2
assert not hasattr(conf2,'invalid_attribute')
def test_argv_override():
conf = InvokeAIAppConfig(init1,['--nsfw_checker','--max_loaded=10'])
assert conf.nsfw_checker
assert conf.max_loaded_models==10
assert conf.outdir==Path('outputs') # this is the default
def test_env_override():
# argv overrides
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==False
os.environ['INVOKEAI_globals_nsfw_checker'] = 'True'
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==True
conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
assert conf.nsfw_checker==False
conf = InvokeAIAppConfig(conf=init1,argv=[],max_loaded_models=20)
assert conf.max_loaded_models==20
# have to comment this one out because of a race condition in setting same
# environment variable in the CI test environment
# assert conf.root==Path('/tmp')
def test_invocation():
InvokeAISettings.initconf=init1
invocation = TextToImageInvocation(id='foobar')
assert invocation.steps==18
assert invocation.scheduler=='k_heun'
assert invocation.height==512 # default
invocation = TextToImageInvocation(id='foobar2',steps=30)
assert invocation.steps==30
def test_type_coercion():
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'])
assert conf.root==Path('/tmp/foobar')
assert isinstance(conf.root,Path)
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'],root='/tmp/different')
assert conf.root==Path('/tmp/different')
assert isinstance(conf.root,Path)