Do less stuff on import of api_app.py. Instead, call functions imperatively.

This commit is contained in:
Ryan Dick 2024-04-25 17:45:54 -04:00
parent caa7c0f2bd
commit 0bdf7f5726
2 changed files with 205 additions and 186 deletions

View File

@ -2,6 +2,7 @@ import asyncio
import logging
import mimetypes
import socket
import time
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
@ -20,13 +21,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
@ -44,84 +42,115 @@ from .api.routers import (
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.baseinvocation import BaseInvocation, UIConfigBase
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config()
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
@asynccontextmanager
async def lifespan(app: FastAPI):
# TODO(ryand): Search for imports from api_app.py in the rest of the codebase and make sure I didn't break any of them.
def build_app(app_config: InvokeAIAppConfig, logger: logging.Logger) -> FastAPI:
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
app = FastAPI(
title="Invoke - Community Edition",
docs_url=None,
redoc_url=None,
separate_input_output_schemas=False,
lifespan=lifespan,
)
)
# Add event handler
event_handler_id: int = id(app)
app.add_middleware(
# Add event handler
event_handler_id: int = id(app)
app.add_middleware(
EventHandlerASGIMiddleware,
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
middleware_id=event_handler_id,
)
)
socket_io = SocketIO(app)
socket_io = SocketIO(app)
app.add_middleware(
app.add_middleware(
CORSMiddleware,
allow_origins=app_config.allow_origins,
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
)
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Include all routers
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
add_custom_openapi(app)
@app.get("/docs", include_in_schema=False)
def overridden_swagger() -> HTMLResponse:
return get_swagger_ui_html(
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=f"{app.title} - Swagger UI",
swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
)
@app.get("/redoc", include_in_schema=False)
def overridden_redoc() -> HTMLResponse:
return get_redoc_html(
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=f"{app.title} - Redoc",
redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
)
web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here
return app
# Include all routers
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
def apply_monkeypatches() -> None:
# TODO(ryand): Don't monkeypatch on import!
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
def fix_mimetypes():
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
def add_custom_openapi(app: FastAPI) -> None:
"""Add a custom .openapi() method to the FastAPI app.
This is done based on the guidance here:
https://fastapi.tiangolo.com/how-to/extending-openapi/#normal-fastapi
"""
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
@ -195,37 +224,7 @@ def custom_openapi() -> dict[str, Any]:
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
@app.get("/docs", include_in_schema=False)
def overridden_swagger() -> HTMLResponse:
return get_swagger_ui_html(
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=f"{app.title} - Swagger UI",
swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
)
@app.get("/redoc", include_in_schema=False)
def overridden_redoc() -> HTMLResponse:
return get_redoc_html(
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=f"{app.title} - Redoc",
redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
)
web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
def check_cudnn(logger: logging.Logger) -> None:
@ -244,8 +243,7 @@ def check_cudnn(logger: logging.Logger) -> None:
)
def invoke_api() -> None:
def find_port(port: int) -> int:
def find_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
@ -255,7 +253,8 @@ def invoke_api() -> None:
else:
return port
if app_config.dev_reload:
def init_dev_reload(logger: logging.Logger) -> None:
try:
import jurigged
except ImportError as e:
@ -266,12 +265,29 @@ def invoke_api() -> None:
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
def invoke_api() -> None:
start = time.time()
app_config = get_config()
logger = InvokeAILogger.get_logger(config=app_config)
apply_monkeypatches()
fix_mimetypes()
if app_config.dev_reload:
init_dev_reload(logger)
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
check_cudnn(logger)
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
app = build_app(app_config, logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(
@ -291,6 +307,7 @@ def invoke_api() -> None:
log.handlers.clear()
for ch in logger.handlers:
log.addHandler(ch)
logger.info(f"API started in {time.time() - start:.2f} seconds")
loop.run_until_complete(server.serve())

View File

@ -7,10 +7,12 @@ import os
from invokeai.app.run_app import run_app
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
def main():
logging.getLogger("xformers").addFilter(
lambda record: "A matching Triton is not available" not in record.getMessage()
)
# Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
run_app()