mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
25 Commits
fix/diffus
...
lstein/enh
Author | SHA1 | Date | |
---|---|---|---|
8b281b792c | |||
447f590fa4 | |||
3f301d3db3 | |||
4e21f5f046 | |||
56462d62ae | |||
e67f2d7f6f | |||
35ef37e481 | |||
cfb08d97f9 | |||
c3a95bda2f | |||
a64bd95df2 | |||
575feb8aee | |||
8318d22d63 | |||
ca04f08668 | |||
5908c75a7d | |||
94fac09cb7 | |||
49dd3d93e7 | |||
2c59164120 | |||
f6fdd8b805 | |||
600258c24c | |||
57d032d7e6 | |||
196c21b7c9 | |||
95f6c85a29 | |||
a0df350608 | |||
1815c15ae2 | |||
629697c4dd |
@ -9,7 +9,8 @@ from ..services.default_graphs import create_system_graphs
|
||||
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ...backend.globals import Globals, copy_conf_to_globals
|
||||
from ..services.config import InvokeAIWebConfig
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
@ -45,11 +46,7 @@ class ApiDependencies:
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
copy_conf_to_globals(config)
|
||||
|
||||
# TODO: Use a logger
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
@ -84,6 +81,7 @@ class ApiDependencies:
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
@ -8,10 +8,6 @@ from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
|
||||
from invokeai.backend.args import Args
|
||||
|
||||
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
@ -12,12 +12,11 @@ from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.config import InvokeAIWebConfig
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
@ -33,30 +32,15 @@ app.add_middleware(
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
origins = []
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
web_config = {}
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=config, event_handler_id=event_handler_id
|
||||
config=web_config, event_handler_id=event_handler_id
|
||||
)
|
||||
|
||||
|
||||
@ -146,12 +130,21 @@ def overridden_redoc():
|
||||
|
||||
|
||||
def invoke_api():
|
||||
# parse command-line settings, environment and the init file
|
||||
# (this is a module global)
|
||||
global web_config
|
||||
web_config = InvokeAIWebConfig()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=web_config.allow_origins,
|
||||
allow_credentials=web_config.allow_credentials,
|
||||
allow_methods=web_config.allow_methods,
|
||||
allow_headers=web_config.allow_headers,
|
||||
)
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
|
||||
config = uvicorn.Config(app=app, host=web_config.host, port=web_config.port, loop=loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
@ -1,91 +1,52 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_type_hints
|
||||
from pydantic import Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..services.config import InvokeAISettings
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
command_field: str = "type",
|
||||
exclude_fields: list[str] = ["id", "type"],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
):
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
# Create subparsers for each command
|
||||
for command in commands:
|
||||
hints = get_type_hints(command)
|
||||
cmd_name = get_args(hints[command_field])[0]
|
||||
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
|
||||
|
||||
name = command.cmd_name()
|
||||
command_parser = subparsers.add_parser(name, help=command.__doc__)
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = command.__fields__ # type: ignore
|
||||
for name, field in fields.items():
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
add_field_argument(command_parser, name, field)
|
||||
|
||||
command.add_parser_arguments(command_parser)
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
graph.add_parser_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
graph.add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
class CliContext:
|
||||
invoker: Invoker
|
||||
@ -130,7 +91,7 @@ class ExitCli(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseCommand(ABC, BaseModel):
|
||||
class BaseCommand(ABC, InvokeAISettings):
|
||||
"""A CLI command"""
|
||||
|
||||
# All commands must include a type name like this:
|
||||
|
@ -10,9 +10,10 @@ import shlex
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||
|
||||
from ...backend import ModelManager, Globals
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .commands import BaseCommand
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
@ -130,13 +131,13 @@ class Completer(object):
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(model_manager)
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
@ -152,7 +153,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
histfile = Path(Globals.root, ".invoke_history")
|
||||
histfile = Path(services.configuration.root_dir / ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
|
@ -19,7 +19,6 @@ from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
@ -35,7 +34,8 @@ from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.globals import copy_conf_to_globals # temporary workaround for code depending on Globals
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
@ -190,24 +190,20 @@ def invoke_all(context: CliContext):
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
config = InvokeAIAppConfig()
|
||||
copy_conf_to_globals(config) # temporary workaround
|
||||
model_manager = get_model_manager(config)
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
completer = set_autocompleter(model_manager)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
output_folder = config.output_path
|
||||
metadata = PngMetadataService()
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
@ -226,6 +222,7 @@ def invoke_cli():
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@ -242,6 +239,8 @@ def invoke_cli():
|
||||
|
||||
context = CliContext(invoker, session, parser)
|
||||
|
||||
set_autocompleter(services)
|
||||
|
||||
while True:
|
||||
try:
|
||||
cmd_input = input("invoke> ")
|
||||
@ -285,8 +284,17 @@ def invoke_cli():
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
if "id" in args:
|
||||
args["id"] = args["id"] or current_id
|
||||
|
||||
# remove extraneous fields from initialization
|
||||
exclude = ['link','link_node']
|
||||
command_args = dict()
|
||||
for key,value in args.items():
|
||||
if key not in exclude:
|
||||
command_args[key]=value
|
||||
|
||||
command = CliCommand(command=command_args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
@ -4,10 +4,10 @@ from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, BaseSettings, Field
|
||||
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
from ..services.config import InvokeAISettings
|
||||
|
||||
class InvocationContext:
|
||||
services: InvocationServices
|
||||
@ -36,7 +36,7 @@ class BaseInvocationOutput(BaseModel):
|
||||
return tuple(subclasses)
|
||||
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
|
||||
class BaseInvocation(ABC, InvokeAISettings):
|
||||
"""A node to process inputs and produce outputs.
|
||||
May use dependency injection in __init__ to receive providers.
|
||||
"""
|
||||
@ -101,8 +101,8 @@ class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
|
||||
|
||||
class InvocationConfig(BaseModel.Config):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
class InvocationConfig(BaseSettings.Config):
|
||||
"""Customizes pydantic's BaseSettings.Config class for use by Invocations.
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
|
||||
|
379
invokeai/app/services/config.py
Normal file
379
invokeai/app/services/config.py
Normal file
@ -0,0 +1,379 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein)
|
||||
|
||||
'''Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
has top-level keys corresponding to an invocation name, a command, or
|
||||
"globals" for global values such as `xformers_enabled`. Currently
|
||||
graphs cannot be configured this way, but their constituents can be.
|
||||
|
||||
[file: invokeai.yaml]
|
||||
|
||||
globals:
|
||||
nsfw_checker: False
|
||||
max_loaded_models: 5
|
||||
|
||||
txt2img:
|
||||
steps: 20
|
||||
scheduler: k_heun
|
||||
width: 768
|
||||
|
||||
img2img:
|
||||
width: 1024
|
||||
height: 1024
|
||||
|
||||
The default name of the configuration file is `invokeai.yaml`, located
|
||||
in INVOKEAI_ROOT. You can use any OmegaConf dictionary by passing it
|
||||
to the config object at initialization time:
|
||||
|
||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||
conf = InvokeAIAppConfig(conf=omegaconf)
|
||||
|
||||
By default, InvokeAIAppConfig will parse the contents of argv at
|
||||
initialization time. You may pass a list of strings in the optional
|
||||
`argv` argument to use instead of the system argv:
|
||||
|
||||
conf = InvokeAIAppConfig(arg=['--xformers_enabled'])
|
||||
|
||||
It is also possible to set a value at initialization time. This value
|
||||
has highest priority.
|
||||
|
||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||
|
||||
Any setting can be overwritten by setting an environment variable of
|
||||
form: "INVOKEAI_<command>_<value>", as in:
|
||||
|
||||
export INVOKEAI_txt2img_steps=30
|
||||
|
||||
Order of precedence (from highest):
|
||||
1) initialization options
|
||||
2) command line options
|
||||
3) environment variable options
|
||||
4) config file options
|
||||
5) pydantic defaults
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.invocations.generate import TextToImageInvocation
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
# get the text2image invocation and print its step value
|
||||
text2image = TextToImageInvocation()
|
||||
print(text2image.steps)
|
||||
|
||||
Computed properties:
|
||||
|
||||
The InvokeAIAppConfig object has a series of properties that
|
||||
resolve paths relative to the runtime root directory. They each return
|
||||
a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
model_conf_path - path to models.yaml
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
|
||||
|
||||
'''
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
'''
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
'''
|
||||
initconf : ClassVar[DictConfig] = None
|
||||
argparse_groups : ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list=sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt, _ = parser.parse_known_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt,name))
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else 'INVOKEAI_'
|
||||
if 'type' in get_type_hints(cls):
|
||||
default_settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||
else:
|
||||
default_settings_stanza = 'globals'
|
||||
initconf = cls.initconf.get(default_settings_stanza) if cls.initconf and default_settings_stanza in cls.initconf else None
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
env_name = env_prefix+f'{cls.cmd_name()}_{name}'
|
||||
if initconf and name in initconf:
|
||||
field.default = initconf.get(name)
|
||||
if env_name in os.environ:
|
||||
field.default = os.environ[env_name]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str='type')->str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return 'globals'
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls)->ArgumentParser:
|
||||
parser = ArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
return ['type','initconf']
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
arbitrary_types_allowed = True
|
||||
env_prefix = 'INVOKEAI_'
|
||||
case_sensitive = True
|
||||
@classmethod
|
||||
def customise_sources(
|
||||
cls,
|
||||
init_settings,
|
||||
env_settings,
|
||||
file_secret_settings,
|
||||
):
|
||||
return (
|
||||
init_settings,
|
||||
cls._omegaconf_settings_source,
|
||||
env_settings,
|
||||
file_secret_settings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _omegaconf_settings_source(cls, settings: BaseSettings) -> dict[str, Any]:
|
||||
if initconf := InvokeAISettings.initconf:
|
||||
return initconf.get(settings.cmd_name(),{})
|
||||
else:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
def _find_root()->Path:
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
|
||||
or
|
||||
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
|
||||
)
|
||||
):
|
||||
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Application-wide settings.
|
||||
'''
|
||||
#fmt: off
|
||||
type: Literal["globals"] = "globals"
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
infile : Path = Field(default=None, description='Path to a file of prompt commands to bulk generate from', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
|
||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
|
||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
#fmt: on
|
||||
|
||||
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
|
||||
'''
|
||||
Initialize InvokeAIAppconfig.
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param **kwargs: attributes to initialize with
|
||||
'''
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
self.parse_args(argv)
|
||||
if not conf:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
self.parse_args(argv)
|
||||
|
||||
# restore initialization values
|
||||
hints = get_type_hints(self)
|
||||
for k in kwargs:
|
||||
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
|
||||
|
||||
@property
|
||||
def root_path(self)->Path:
|
||||
'''
|
||||
Path to the runtime root directory
|
||||
'''
|
||||
if self.root:
|
||||
return self.root.expanduser()
|
||||
else:
|
||||
return self.find_root()
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
'''
|
||||
Alias for above.
|
||||
'''
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
'''
|
||||
Path to defaults outputs directory.
|
||||
'''
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
'''
|
||||
Path to models configuration file.
|
||||
'''
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def conf(self)->Path:
|
||||
'''
|
||||
Path to models configuration file (alias for model_conf_path).
|
||||
'''
|
||||
return self.model_conf_path
|
||||
|
||||
@property
|
||||
def embedding_path(self)->Path:
|
||||
'''
|
||||
Path to the textual inversion embeddings directory.
|
||||
'''
|
||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
||||
|
||||
@property
|
||||
def lora_path(self)->Path:
|
||||
'''
|
||||
Path to the LoRA models directory.
|
||||
'''
|
||||
return self._resolve(self.lora_dir) if self.lora_dir else None
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
'''
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
'''
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
@property
|
||||
def gfpgan_model_path(self)->Path:
|
||||
'''
|
||||
Path to the GFPGAN model.
|
||||
'''
|
||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
||||
|
||||
@staticmethod
|
||||
def find_root()->Path:
|
||||
'''
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
'''
|
||||
return _find_root()
|
||||
|
||||
class InvokeAIWebConfig(InvokeAIAppConfig):
|
||||
'''
|
||||
Web-specific settings
|
||||
'''
|
||||
#fmt: off
|
||||
type : Literal["web"] = "web"
|
||||
allow_origins : List = Field(default=[], description="Allowed CORS origins", category='Cross-Origin Resource Sharing')
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Cross-Origin Resource Sharing')
|
||||
allow_methods : List = Field(default=["*"], description="Methods allowed for CORS", category='Cross-Origin Resource Sharing')
|
||||
allow_headers : List = Field(default=["*"], description="Headers allowed for CORS", category='Cross-Origin Resource Sharing')
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
#fmt: on
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, TypedDict, Union
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
import copy
|
||||
import itertools
|
||||
import uuid
|
||||
from types import NoneType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
@ -14,9 +13,10 @@ from typing import (
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
NoneType = type(None)
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic import BaseModel, root_validator, validator, Extra
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
@ -25,6 +25,7 @@ from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
)
|
||||
from .config import InvokeAISettings
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
@ -211,9 +212,10 @@ class CollectInvocation(BaseInvocation):
|
||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
class Graph(InvokeAISettings):
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
type: Literal["graph"] = "graph"
|
||||
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@ -805,7 +807,7 @@ class GraphExecutionState(BaseModel):
|
||||
]
|
||||
}
|
||||
|
||||
def next(self) -> BaseInvocation | None:
|
||||
def next(self) -> Union[BaseInvocation, None]:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
||||
@ -1154,7 +1156,7 @@ class ExposedNodeOutput(BaseModel):
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
class LibraryGraph(InvokeAISettings):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
@ -1162,6 +1164,9 @@ class LibraryGraph(BaseModel):
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
|
||||
class Config:
|
||||
extra='allow'
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
|
@ -3,7 +3,7 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
|
||||
from typing import Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ class InvocationQueueABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, item: InvocationQueueItem | None) -> None:
|
||||
def put(self, item: Union[InvocationQueueItem, None]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -57,7 +57,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
||||
|
||||
return item
|
||||
|
||||
def put(self, item: InvocationQueueItem | None) -> None:
|
||||
def put(self, item: Union[InvocationQueueItem, None]) -> None:
|
||||
self.__queue.put(item)
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
|
@ -8,6 +8,7 @@ from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
from .config import InvokeAISettings
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
@ -19,6 +20,7 @@ class InvocationServices:
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
configuration: InvokeAISettings
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
@ -37,6 +39,7 @@ class InvocationServices:
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
configuration: InvokeAISettings=None,
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
@ -48,3 +51,4 @@ class InvocationServices:
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from abc import ABC
|
||||
from threading import Event, Thread
|
||||
from typing import Union
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .graph import Graph, GraphExecutionState
|
||||
@ -21,7 +22,7 @@ class Invoker:
|
||||
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> str | None:
|
||||
) -> Union[str, None]:
|
||||
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
|
||||
|
||||
# Get the next invocation
|
||||
@ -44,7 +45,7 @@ class Invoker:
|
||||
|
||||
return invocation.id
|
||||
|
||||
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
|
||||
def create_execution_state(self, graph: Union[Graph, None] = None) -> GraphExecutionState:
|
||||
"""Creates a new execution state for the given graph"""
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -56,7 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
|
||||
def __get_cache(self, name: str) -> torch.Tensor|None:
|
||||
def __get_cache(self, name: str) -> Union[torch.Tensor,None]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
|
@ -7,21 +7,20 @@ from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.version
|
||||
from .config import InvokeAISettings
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: Args) -> ModelManager:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
def get_model_manager(config:InvokeAISettings) -> ModelManager:
|
||||
model_config = config.model_conf_path
|
||||
if not model_config.exists():
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||
config, FileNotFoundError(f"The file {model_config} could not be found.")
|
||||
)
|
||||
|
||||
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
print(f'>> InvokeAI runtime directory is "{config.root_dir}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
@ -31,20 +30,7 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
import diffusers
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(config.conf):
|
||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
||||
|
||||
if config.embeddings:
|
||||
if not os.path.isabs(config.embedding_path):
|
||||
embedding_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config.embedding_path)
|
||||
)
|
||||
else:
|
||||
embedding_path = config.embedding_path
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
# migrate legacy models
|
||||
ModelManager.migrate_models()
|
||||
@ -57,11 +43,11 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
else choose_precision(device)
|
||||
|
||||
model_manager = ModelManager(
|
||||
OmegaConf.load(config.conf),
|
||||
OmegaConf.load(model_config),
|
||||
precision=precision,
|
||||
device_type=device,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
embedding_path = Path(embedding_path),
|
||||
embedding_path = embedding_path,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e)
|
||||
@ -71,12 +57,10 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := config.autoconvert:
|
||||
model_manager.autoconvert_weights(
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
if config.autoconvert_path:
|
||||
model_manager.heuristic_import(
|
||||
config.autoconvert_path,
|
||||
)
|
||||
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
|
@ -16,6 +16,7 @@ import os.path as osp
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from pydantic import BaseSettings
|
||||
|
||||
Globals = Namespace()
|
||||
|
||||
@ -120,3 +121,15 @@ def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
|
||||
return Path(home, subdir)
|
||||
else:
|
||||
return Path(Globals.root, "models", subdir)
|
||||
|
||||
def copy_conf_to_globals(conf: Union[dict,BaseSettings]):
|
||||
'''
|
||||
Given a dict or dict-like object, copy its keys and
|
||||
values into the Globals Namespace. This is a transitional
|
||||
workaround until we remove Globals entirely.
|
||||
'''
|
||||
if isinstance(conf,BaseSettings):
|
||||
conf = conf.dict()
|
||||
for key in conf.keys():
|
||||
if key is not None:
|
||||
setattr(Globals,key,conf[key])
|
||||
|
@ -8,7 +8,6 @@ from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_graph():
|
||||
g = Graph()
|
||||
|
@ -316,7 +316,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
|
||||
|
||||
def test_graph_iterator_invalid_if_input_not_list():
|
||||
g = Graph()
|
||||
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
|
||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = IterateInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
|
99
tests/test_config.py
Normal file
99
tests/test_config.py
Normal file
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
os.environ['INVOKEAI_ROOT']='/tmp'
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, InvokeAISettings
|
||||
from invokeai.app.invocations.generate import TextToImageInvocation
|
||||
|
||||
init1 = OmegaConf.create(
|
||||
'''
|
||||
globals:
|
||||
nsfw_checker: False
|
||||
max_loaded_models: 5
|
||||
|
||||
history:
|
||||
count: 100
|
||||
|
||||
txt2img:
|
||||
steps: 18
|
||||
scheduler: k_heun
|
||||
width: 768
|
||||
|
||||
img2img:
|
||||
width: 1024
|
||||
height: 1024
|
||||
'''
|
||||
)
|
||||
|
||||
init2 = OmegaConf.create(
|
||||
'''
|
||||
globals:
|
||||
nsfw_checker: True
|
||||
max_loaded_models: 2
|
||||
|
||||
history:
|
||||
count: 10
|
||||
'''
|
||||
)
|
||||
|
||||
def test_use_init():
|
||||
# note that we explicitly set omegaconf dict and argv here
|
||||
# so that the values aren't read from ~invokeai/invokeai.yaml and
|
||||
# sys.argv respectively.
|
||||
conf1 = InvokeAIAppConfig(init1,[])
|
||||
assert conf1
|
||||
assert conf1.max_loaded_models==5
|
||||
assert not conf1.nsfw_checker
|
||||
|
||||
conf2 = InvokeAIAppConfig(init2,[])
|
||||
assert conf2
|
||||
assert conf2.nsfw_checker
|
||||
assert conf2.max_loaded_models==2
|
||||
assert not hasattr(conf2,'invalid_attribute')
|
||||
|
||||
|
||||
def test_argv_override():
|
||||
conf = InvokeAIAppConfig(init1,['--nsfw_checker','--max_loaded=10'])
|
||||
assert conf.nsfw_checker
|
||||
assert conf.max_loaded_models==10
|
||||
assert conf.outdir==Path('outputs') # this is the default
|
||||
|
||||
def test_env_override():
|
||||
# argv overrides
|
||||
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
||||
assert conf.nsfw_checker==False
|
||||
|
||||
os.environ['INVOKEAI_globals_nsfw_checker'] = 'True'
|
||||
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
||||
assert conf.nsfw_checker==True
|
||||
|
||||
conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
|
||||
assert conf.nsfw_checker==False
|
||||
|
||||
conf = InvokeAIAppConfig(conf=init1,argv=[],max_loaded_models=20)
|
||||
assert conf.max_loaded_models==20
|
||||
|
||||
# have to comment this one out because of a race condition in setting same
|
||||
# environment variable in the CI test environment
|
||||
# assert conf.root==Path('/tmp')
|
||||
|
||||
def test_invocation():
|
||||
InvokeAISettings.initconf=init1
|
||||
invocation = TextToImageInvocation(id='foobar')
|
||||
assert invocation.steps==18
|
||||
assert invocation.scheduler=='k_heun'
|
||||
assert invocation.height==512 # default
|
||||
|
||||
invocation = TextToImageInvocation(id='foobar2',steps=30)
|
||||
assert invocation.steps==30
|
||||
|
||||
def test_type_coercion():
|
||||
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'])
|
||||
assert conf.root==Path('/tmp/foobar')
|
||||
assert isinstance(conf.root,Path)
|
||||
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'],root='/tmp/different')
|
||||
assert conf.root==Path('/tmp/different')
|
||||
assert isinstance(conf.root,Path)
|
Reference in New Issue
Block a user