fix(api): flesh out types for api_app.py

This commit is contained in:
psychedelicious 2023-10-18 13:43:01 +11:00
parent e4c45012f4
commit da403ba04c

View File

@ -1,3 +1,5 @@
from typing import Any
from fastapi.responses import HTMLResponse
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config # parse_args() must be called before any other imports. if it is not called first, consumers of the config
@ -13,7 +15,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from inspect import signature from inspect import signature
from pathlib import Path from pathlib import Path
import torch
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -26,6 +27,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from pydantic.json_schema import models_json_schema from pydantic.json_schema import models_json_schema
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
# for PyCharm:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
@ -36,16 +38,15 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
from torch.backends.mps import is_available as is_mps_available
if torch.backends.mps.is_available(): if is_mps_available():
# noinspection PyUnresolvedReferences
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
app_config.parse_args() app_config.parse_args()
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked # fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("application/javascript", ".js")
@ -78,13 +79,13 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event() -> None:
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger) ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
# Shut down threads # Shut down threads
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event(): async def shutdown_event() -> None:
ApiDependencies.shutdown() ApiDependencies.shutdown()
@ -108,7 +109,7 @@ app.include_router(session_queue.session_queue_router, prefix="/api")
# Build a custom OpenAPI to include all outputs # Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow? # TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi(): def custom_openapi() -> dict[str, Any]:
if app.openapi_schema: if app.openapi_schema:
return app.openapi_schema return app.openapi_schema
openapi_schema = get_openapi( openapi_schema = get_openapi(
@ -179,18 +180,18 @@ app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid a
@app.get("/docs", include_in_schema=False) @app.get("/docs", include_in_schema=False)
def overridden_swagger(): def overridden_swagger() -> HTMLResponse:
return get_swagger_ui_html( return get_swagger_ui_html(
openapi_url=app.openapi_url, openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=app.title, title=app.title,
swagger_favicon_url="/static/docs/favicon.ico", swagger_favicon_url="/static/docs/favicon.ico",
) )
@app.get("/redoc", include_in_schema=False) @app.get("/redoc", include_in_schema=False)
def overridden_redoc(): def overridden_redoc() -> HTMLResponse:
return get_redoc_html( return get_redoc_html(
openapi_url=app.openapi_url, openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=app.title, title=app.title,
redoc_favicon_url="/static/docs/favicon.ico", redoc_favicon_url="/static/docs/favicon.ico",
) )
@ -212,8 +213,8 @@ app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")),
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales") app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
def invoke_api(): def invoke_api() -> None:
def find_port(port: int): def find_port(port: int) -> int:
"""Find a port not in use starting at given port""" """Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker # https://github.com/WaylonWalker