From 0bdf7f57261c053b3007f78220ac17f911d42889 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 25 Apr 2024 17:45:54 -0400 Subject: [PATCH] Do less stuff on import of api_app.py. Instead, call functions imperatively. --- invokeai/app/api_app.py | 385 +++++++++++++++++++++------------------- scripts/invokeai-web.py | 6 +- 2 files changed, 205 insertions(+), 186 deletions(-) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index ceaeb95147..71e61e074d 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -2,6 +2,7 @@ import asyncio import logging import mimetypes import socket +import time from contextlib import asynccontextmanager from inspect import signature from pathlib import Path @@ -20,13 +21,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available -# for PyCharm: -# noinspection PyUnresolvedReferences -import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.backend.util.devices import TorchDevice @@ -44,188 +42,189 @@ from .api.routers import ( workflows, ) from .api.sockets import SocketIO -from .invocations.baseinvocation import ( - BaseInvocation, - UIConfigBase, -) +from .invocations.baseinvocation import BaseInvocation, UIConfigBase from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra -app_config = get_config() + +# TODO(ryand): Search for imports from api_app.py in the rest of the codebase and make sure I didn't break any of them. +def build_app(app_config: InvokeAIAppConfig, logger: logging.Logger) -> FastAPI: + @asynccontextmanager + async def lifespan(app: FastAPI): + # Add startup event to load dependencies + ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger) + yield + # Shut down threads + ApiDependencies.shutdown() + + app = FastAPI( + title="Invoke - Community Edition", + docs_url=None, + redoc_url=None, + separate_input_output_schemas=False, + lifespan=lifespan, + ) + + # Add event handler + event_handler_id: int = id(app) + app.add_middleware( + EventHandlerASGIMiddleware, + handlers=[local_handler], # TODO: consider doing this in services to support different configurations + middleware_id=event_handler_id, + ) + + socket_io = SocketIO(app) + + app.add_middleware( + CORSMiddleware, + allow_origins=app_config.allow_origins, + allow_credentials=app_config.allow_credentials, + allow_methods=app_config.allow_methods, + allow_headers=app_config.allow_headers, + ) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + + # Include all routers + app.include_router(utilities.utilities_router, prefix="/api") + app.include_router(model_manager.model_manager_router, prefix="/api") + app.include_router(download_queue.download_queue_router, prefix="/api") + app.include_router(images.images_router, prefix="/api") + app.include_router(boards.boards_router, prefix="/api") + app.include_router(board_images.board_images_router, prefix="/api") + app.include_router(app_info.app_router, prefix="/api") + app.include_router(session_queue.session_queue_router, prefix="/api") + app.include_router(workflows.workflows_router, prefix="/api") + + add_custom_openapi(app) + + @app.get("/docs", include_in_schema=False) + def overridden_swagger() -> HTMLResponse: + return get_swagger_ui_html( + openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string + title=f"{app.title} - Swagger UI", + swagger_favicon_url="static/docs/invoke-favicon-docs.svg", + ) + + @app.get("/redoc", include_in_schema=False) + def overridden_redoc() -> HTMLResponse: + return get_redoc_html( + openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string + title=f"{app.title} - Redoc", + redoc_favicon_url="static/docs/invoke-favicon-docs.svg", + ) + + web_root_path = Path(list(web_dir.__path__)[0]) + + try: + app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui") + except RuntimeError: + logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount") + + app.mount( + "/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static" + ) # docs favicon is in here + + return app -if is_mps_available(): - import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) +def apply_monkeypatches() -> None: + # TODO(ryand): Don't monkeypatch on import! + if is_mps_available(): + import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) + import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) -logger = InvokeAILogger.get_logger(config=app_config) -# fix for windows mimetypes registry entries being borked -# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 -mimetypes.add_type("application/javascript", ".js") -mimetypes.add_type("text/css", ".css") - -torch_device_name = TorchDevice.get_torch_device_name() -logger.info(f"Using torch device: {torch_device_name}") +def fix_mimetypes(): + # fix for windows mimetypes registry entries being borked + # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 + mimetypes.add_type("application/javascript", ".js") + mimetypes.add_type("text/css", ".css") -@asynccontextmanager -async def lifespan(app: FastAPI): - # Add startup event to load dependencies - ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger) - yield - # Shut down threads - ApiDependencies.shutdown() +def add_custom_openapi(app: FastAPI) -> None: + """Add a custom .openapi() method to the FastAPI app. + This is done based on the guidance here: + https://fastapi.tiangolo.com/how-to/extending-openapi/#normal-fastapi + """ -# Create the app -# TODO: create this all in a method so configuration/etc. can be passed in? -app = FastAPI( - title="Invoke - Community Edition", - docs_url=None, - redoc_url=None, - separate_input_output_schemas=False, - lifespan=lifespan, -) + # Build a custom OpenAPI to include all outputs + # TODO: can outputs be included on metadata of invocation schemas somehow? + def custom_openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + openapi_schema = get_openapi( + title=app.title, + description="An API for invoking AI image operations", + version="1.0.0", + routes=app.routes, + separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ + ) -# Add event handler -event_handler_id: int = id(app) -app.add_middleware( - EventHandlerASGIMiddleware, - handlers=[local_handler], # TODO: consider doing this in services to support different configurations - middleware_id=event_handler_id, -) + # Add all outputs + all_invocations = BaseInvocation.get_invocations() + output_types = set() + output_type_titles = {} + for invoker in all_invocations: + output_type = signature(invoker.invoke).return_annotation + output_types.add(output_type) -socket_io = SocketIO(app) + output_schemas = models_json_schema( + models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}" + ) + for schema_key, output_schema in output_schemas[1]["$defs"].items(): + # TODO: note that we assume the schema_key here is the TYPE.__name__ + # This could break in some cases, figure out a better way to do it + output_type_titles[schema_key] = output_schema["title"] + openapi_schema["components"]["schemas"][schema_key] = output_schema + openapi_schema["components"]["schemas"][schema_key]["class"] = "output" -app.add_middleware( - CORSMiddleware, - allow_origins=app_config.allow_origins, - allow_credentials=app_config.allow_credentials, - allow_methods=app_config.allow_methods, - allow_headers=app_config.allow_headers, -) + # Some models don't end up in the schemas as standalone definitions + additional_schemas = models_json_schema( + [ + (UIConfigBase, "serialization"), + (InputFieldJSONSchemaExtra, "serialization"), + (OutputFieldJSONSchemaExtra, "serialization"), + (ModelIdentifierField, "serialization"), + (ProgressImage, "serialization"), + ], + ref_template="#/components/schemas/{model}", + ) + for schema_key, schema_json in additional_schemas[1]["$defs"].items(): + openapi_schema["components"]["schemas"][schema_key] = schema_json -app.add_middleware(GZipMiddleware, minimum_size=1000) + # Add a reference to the output type to additionalProperties of the invoker schema + for invoker in all_invocations: + invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute + output_type = signature(obj=invoker.invoke).return_annotation + output_type_title = output_type_titles[output_type.__name__] + invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] + outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} + invoker_schema["output"] = outputs_ref + invoker_schema["class"] = "invocation" + # This code no longer seems to be necessary? + # Leave it here just in case + # + # from invokeai.backend.model_manager import get_model_config_formats + # formats = get_model_config_formats() + # for model_config_name, enum_set in formats.items(): -# Include all routers -app.include_router(utilities.utilities_router, prefix="/api") -app.include_router(model_manager.model_manager_router, prefix="/api") -app.include_router(download_queue.download_queue_router, prefix="/api") -app.include_router(images.images_router, prefix="/api") -app.include_router(boards.boards_router, prefix="/api") -app.include_router(board_images.board_images_router, prefix="/api") -app.include_router(app_info.app_router, prefix="/api") -app.include_router(session_queue.session_queue_router, prefix="/api") -app.include_router(workflows.workflows_router, prefix="/api") + # if model_config_name in openapi_schema["components"]["schemas"]: + # # print(f"Config with name {name} already defined") + # continue + # openapi_schema["components"]["schemas"][model_config_name] = { + # "title": model_config_name, + # "description": "An enumeration.", + # "type": "string", + # "enum": [v.value for v in enum_set], + # } -# Build a custom OpenAPI to include all outputs -# TODO: can outputs be included on metadata of invocation schemas somehow? -def custom_openapi() -> dict[str, Any]: - if app.openapi_schema: + app.openapi_schema = openapi_schema return app.openapi_schema - openapi_schema = get_openapi( - title=app.title, - description="An API for invoking AI image operations", - version="1.0.0", - routes=app.routes, - separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ - ) - # Add all outputs - all_invocations = BaseInvocation.get_invocations() - output_types = set() - output_type_titles = {} - for invoker in all_invocations: - output_type = signature(invoker.invoke).return_annotation - output_types.add(output_type) - - output_schemas = models_json_schema( - models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}" - ) - for schema_key, output_schema in output_schemas[1]["$defs"].items(): - # TODO: note that we assume the schema_key here is the TYPE.__name__ - # This could break in some cases, figure out a better way to do it - output_type_titles[schema_key] = output_schema["title"] - openapi_schema["components"]["schemas"][schema_key] = output_schema - openapi_schema["components"]["schemas"][schema_key]["class"] = "output" - - # Some models don't end up in the schemas as standalone definitions - additional_schemas = models_json_schema( - [ - (UIConfigBase, "serialization"), - (InputFieldJSONSchemaExtra, "serialization"), - (OutputFieldJSONSchemaExtra, "serialization"), - (ModelIdentifierField, "serialization"), - (ProgressImage, "serialization"), - ], - ref_template="#/components/schemas/{model}", - ) - for schema_key, schema_json in additional_schemas[1]["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = schema_json - - # Add a reference to the output type to additionalProperties of the invoker schema - for invoker in all_invocations: - invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute - output_type = signature(obj=invoker.invoke).return_annotation - output_type_title = output_type_titles[output_type.__name__] - invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] - outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} - invoker_schema["output"] = outputs_ref - invoker_schema["class"] = "invocation" - - # This code no longer seems to be necessary? - # Leave it here just in case - # - # from invokeai.backend.model_manager import get_model_config_formats - # formats = get_model_config_formats() - # for model_config_name, enum_set in formats.items(): - - # if model_config_name in openapi_schema["components"]["schemas"]: - # # print(f"Config with name {name} already defined") - # continue - - # openapi_schema["components"]["schemas"][model_config_name] = { - # "title": model_config_name, - # "description": "An enumeration.", - # "type": "string", - # "enum": [v.value for v in enum_set], - # } - - app.openapi_schema = openapi_schema - return app.openapi_schema - - -app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment - - -@app.get("/docs", include_in_schema=False) -def overridden_swagger() -> HTMLResponse: - return get_swagger_ui_html( - openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string - title=f"{app.title} - Swagger UI", - swagger_favicon_url="static/docs/invoke-favicon-docs.svg", - ) - - -@app.get("/redoc", include_in_schema=False) -def overridden_redoc() -> HTMLResponse: - return get_redoc_html( - openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string - title=f"{app.title} - Redoc", - redoc_favicon_url="static/docs/invoke-favicon-docs.svg", - ) - - -web_root_path = Path(list(web_dir.__path__)[0]) - -try: - app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui") -except RuntimeError: - logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount") -app.mount( - "/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static" -) # docs favicon is in here + app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment def check_cudnn(logger: logging.Logger) -> None: @@ -244,27 +243,39 @@ def check_cudnn(logger: logging.Logger) -> None: ) +def find_port(port: int) -> int: + """Find a port not in use starting at given port""" + # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! + # https://github.com/WaylonWalker + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + if s.connect_ex(("localhost", port)) == 0: + return find_port(port=port + 1) + else: + return port + + +def init_dev_reload(logger: logging.Logger) -> None: + try: + import jurigged + except ImportError as e: + logger.error( + 'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.', + exc_info=e, + ) + else: + jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info) + + def invoke_api() -> None: - def find_port(port: int) -> int: - """Find a port not in use starting at given port""" - # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! - # https://github.com/WaylonWalker - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - if s.connect_ex(("localhost", port)) == 0: - return find_port(port=port + 1) - else: - return port + start = time.time() + app_config = get_config() + logger = InvokeAILogger.get_logger(config=app_config) + + apply_monkeypatches() + fix_mimetypes() if app_config.dev_reload: - try: - import jurigged - except ImportError as e: - logger.error( - 'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.', - exc_info=e, - ) - else: - jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info) + init_dev_reload(logger) port = find_port(app_config.port) if port != app_config.port: @@ -272,6 +283,11 @@ def invoke_api() -> None: check_cudnn(logger) + torch_device_name = TorchDevice.get_torch_device_name() + logger.info(f"Using torch device: {torch_device_name}") + + app = build_app(app_config, logger) + # Start our own event loop for eventing usage loop = asyncio.new_event_loop() config = uvicorn.Config( @@ -291,6 +307,7 @@ def invoke_api() -> None: log.handlers.clear() for ch in logger.handlers: log.addHandler(ch) + logger.info(f"API started in {time.time() - start:.2f} seconds") loop.run_until_complete(server.serve()) diff --git a/scripts/invokeai-web.py b/scripts/invokeai-web.py index 691e58f7d1..ab979dd867 100755 --- a/scripts/invokeai-web.py +++ b/scripts/invokeai-web.py @@ -7,10 +7,12 @@ import os from invokeai.app.run_app import run_app -logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage()) - def main(): + logging.getLogger("xformers").addFilter( + lambda record: "A matching Triton is not available" not in record.getMessage() + ) + # Change working directory to the repo root os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) run_app()