use InvokeAISettings for app-wide configuration

This commit is contained in:
Lincoln Stein 2023-05-03 22:30:30 -04:00
parent 5e8c97f1ba
commit 90054ddf0d
9 changed files with 80 additions and 86 deletions

View File

@ -7,7 +7,7 @@ from typing import types
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.config import InvokeAIWebConfig
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
@ -42,17 +42,8 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
# TO DO: Use the config to select the logger rather than use the default
# invokeai logging module
logger.info(f"Internet connectivity is {Globals.internet_available}")
logger.info(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
@ -72,7 +63,6 @@ class ApiDependencies:
services = InvocationServices(
model_manager=get_model_manager(config,logger),
events=events,
logger=logger,
latents=latents,
images=images,
metadata=metadata,
@ -85,6 +75,8 @@ class ApiDependencies:
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
configuration=config,
logger=logger,
)
create_system_graphs(services.graph_library)

View File

@ -13,11 +13,11 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIWebConfig
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
@ -33,30 +33,14 @@ app.add_middleware(
middleware_id=event_handler_id,
)
# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
socket_io = SocketIO(app)
config = {}
web_config = {}
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
config = Args()
config.parse_args()
ApiDependencies.initialize(
config=config, event_handler_id=event_handler_id, logger=logger
config=web_config, event_handler_id=event_handler_id, logger=logger
)
@ -146,15 +130,23 @@ def overridden_redoc():
def invoke_api():
# parse command-line settings, environment and the init file
# (this is a module global)
global web_config
web_config = InvokeAIWebConfig()
app.add_middleware(
CORSMiddleware,
allow_origins=web_config.allow_origins,
allow_credentials=web_config.allow_credentials,
allow_methods=web_config.allow_methods,
allow_headers=web_config.allow_headers,
)
# Start our own event loop for eventing usage
# TODO: determine if there's a better way to do this
loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
config = uvicorn.Config(app=app, host=web_config.host, port=web_config.port, loop=loop)
# Use access_log to turn off logging
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

View File

@ -285,3 +285,19 @@ class DrawExecutionGraphCommand(BaseCommand):
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
class SortedHelpFormatter(argparse.HelpFormatter):
def _iter_indented_subactions(self, action):
try:
get_subactions = action._get_subactions
except AttributeError:
pass
else:
self._indent()
if isinstance(action, argparse._SubParsersAction):
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
yield subaction
else:
for subaction in get_subactions():
yield subaction
self._dedent()

View File

@ -11,9 +11,10 @@ from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
import invokeai.backend.util.logging as logger
from ...backend import ModelManager, Globals
from ...backend import ModelManager
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
from ..services.invocation_services import InvocationServices
# singleton object, class variable
completer = None
@ -131,13 +132,13 @@ class Completer(object):
readline.redisplay()
self.linebuffer = None
def set_autocompleter(model_manager: ModelManager) -> Completer:
def set_autocompleter(services: InvocationServices) -> Completer:
global completer
if completer:
return completer
completer = Completer(model_manager)
completer = Completer(services.model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
@ -153,7 +154,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
histfile = Path(Globals.root, ".invoke_history")
histfile = Path(services.configuration.root_dir / ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)

View File

@ -19,8 +19,7 @@ from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
@ -34,7 +33,7 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
from .services.config import InvokeAIAppConfig
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -64,7 +63,7 @@ def add_invocation_args(command_parser):
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
def exit(*args, **kwargs):
raise InvalidArgs
@ -189,24 +188,13 @@ def invoke_all(context: CliContext):
def invoke_cli():
config = Args()
config.parse_args()
config = InvokeAIAppConfig()
model_manager = get_model_manager(config,logger=logger)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
set_autocompleter(model_manager)
events = EventServiceBase()
output_folder = config.output_path
metadata = PngMetadataService()
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@ -226,6 +214,7 @@ def invoke_cli():
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
@ -241,6 +230,7 @@ def invoke_cli():
# print(services.session_manager.list())
context = CliContext(invoker, session, parser)
set_autocompleter(services)
while True:
try:

View File

@ -135,6 +135,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
type: Literal["graph"] = "graph"
# TODO: figure out how to create a default here
@ -162,6 +163,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
type: Literal["iterate"] = "iterate"
collection: list[Any] = Field(

View File

@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings
class InvocationServices:
"""Services that can be used by invocations"""
@ -21,7 +22,8 @@ class InvocationServices:
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
@ -40,6 +42,7 @@ class InvocationServices:
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
):
self.model_manager = model_manager
self.events = events
@ -52,3 +55,4 @@ class InvocationServices:
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration

View File

@ -8,21 +8,20 @@ from pathlib import Path
from typing import types
import invokeai.version
from .config import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
)
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(
config, FileNotFoundError(f"The file {model_config} could not be found."), logger
)
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
logger.info(f'InvokeAI runtime directory is "{config.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@ -32,20 +31,7 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
import diffusers
diffusers.logging.set_verbosity_error()
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = config.embedding_path
else:
embedding_path = None
embedding_path = config.embedding_path
# migrate legacy models
ModelManager.migrate_models()
@ -62,7 +48,7 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path),
embedding_path = embedding_path,
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
@ -73,12 +59,10 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
# try to autoconvert new models
# autoimport new .ckpt files
if path := config.autoconvert:
model_manager.autoconvert_weights(
conf_path=config.conf,
weights_directory=path,
if config.autoconvert_path:
model_manager.heuristic_import(
config.autoconvert_path,
)
logger.info('Model manager initialized')
return model_manager
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):

View File

@ -16,6 +16,7 @@ import os.path as osp
from argparse import Namespace
from pathlib import Path
from typing import Union
from pydantic import BaseSettings
Globals = Namespace()
@ -120,3 +121,15 @@ def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
return Path(home, subdir)
else:
return Path(Globals.root, "models", subdir)
def copy_conf_to_globals(conf: Union[dict,BaseSettings]):
'''
Given a dict or dict-like object, copy its keys and
values into the Globals Namespace. This is a transitional
workaround until we remove Globals entirely.
'''
if isinstance(conf,BaseSettings):
conf = conf.dict()
for key in conf.keys():
if key is not None:
setattr(Globals,key,conf[key])