mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(api): flesh out types for api_app.py
This commit is contained in:
parent
e4c45012f4
commit
da403ba04c
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
@ -13,7 +15,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@ -26,6 +27,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
from pydantic.json_schema import models_json_schema
|
from pydantic.json_schema import models_json_schema
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
# for PyCharm:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
@ -36,16 +38,15 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
from torch.backends.mps import is_available as is_mps_available
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if is_mps_available():
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
app_config.parse_args()
|
app_config.parse_args()
|
||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
# fix for windows mimetypes registry entries being borked
|
||||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||||
mimetypes.add_type("application/javascript", ".js")
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
@ -78,13 +79,13 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
|
|||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event() -> None:
|
||||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
# Shut down threads
|
# Shut down threads
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
async def shutdown_event():
|
async def shutdown_event() -> None:
|
||||||
ApiDependencies.shutdown()
|
ApiDependencies.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@ -108,7 +109,7 @@ app.include_router(session_queue.session_queue_router, prefix="/api")
|
|||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
def custom_openapi():
|
def custom_openapi() -> dict[str, Any]:
|
||||||
if app.openapi_schema:
|
if app.openapi_schema:
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
openapi_schema = get_openapi(
|
openapi_schema = get_openapi(
|
||||||
@ -179,18 +180,18 @@ app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid a
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/docs", include_in_schema=False)
|
@app.get("/docs", include_in_schema=False)
|
||||||
def overridden_swagger():
|
def overridden_swagger() -> HTMLResponse:
|
||||||
return get_swagger_ui_html(
|
return get_swagger_ui_html(
|
||||||
openapi_url=app.openapi_url,
|
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||||
title=app.title,
|
title=app.title,
|
||||||
swagger_favicon_url="/static/docs/favicon.ico",
|
swagger_favicon_url="/static/docs/favicon.ico",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/redoc", include_in_schema=False)
|
@app.get("/redoc", include_in_schema=False)
|
||||||
def overridden_redoc():
|
def overridden_redoc() -> HTMLResponse:
|
||||||
return get_redoc_html(
|
return get_redoc_html(
|
||||||
openapi_url=app.openapi_url,
|
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||||
title=app.title,
|
title=app.title,
|
||||||
redoc_favicon_url="/static/docs/favicon.ico",
|
redoc_favicon_url="/static/docs/favicon.ico",
|
||||||
)
|
)
|
||||||
@ -212,8 +213,8 @@ app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")),
|
|||||||
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
|
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
|
||||||
|
|
||||||
|
|
||||||
def invoke_api():
|
def invoke_api() -> None:
|
||||||
def find_port(port: int):
|
def find_port(port: int) -> int:
|
||||||
"""Find a port not in use starting at given port"""
|
"""Find a port not in use starting at given port"""
|
||||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||||
# https://github.com/WaylonWalker
|
# https://github.com/WaylonWalker
|
||||||
|
Loading…
Reference in New Issue
Block a user