From da403ba04c41f85e007da6b19a088a2627406536 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 18 Oct 2023 13:43:01 +1100 Subject: [PATCH] fix(api): flesh out types for `api_app.py` --- invokeai/app/api_app.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 5889c7e228..6bdf358147 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,3 +1,5 @@ +from typing import Any +from fastapi.responses import HTMLResponse from .services.config import InvokeAIAppConfig # parse_args() must be called before any other imports. if it is not called first, consumers of the config @@ -13,7 +15,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c from inspect import signature from pathlib import Path - import torch import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -26,6 +27,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c from pydantic.json_schema import models_json_schema from fastapi.responses import FileResponse + # for PyCharm: # noinspection PyUnresolvedReferences import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir @@ -36,16 +38,15 @@ if True: # hack to make flake8 happy with imports coming after setting up the c from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField + from torch.backends.mps import is_available as is_mps_available - if torch.backends.mps.is_available(): - # noinspection PyUnresolvedReferences + if is_mps_available(): import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) app_config = InvokeAIAppConfig.get_config() app_config.parse_args() logger = InvokeAILogger.get_logger(config=app_config) - # fix for windows mimetypes registry entries being borked # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 mimetypes.add_type("application/javascript", ".js") @@ -78,13 +79,13 @@ app.add_middleware(GZipMiddleware, minimum_size=1000) # Add startup event to load dependencies @app.on_event("startup") -async def startup_event(): +async def startup_event() -> None: ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger) # Shut down threads @app.on_event("shutdown") -async def shutdown_event(): +async def shutdown_event() -> None: ApiDependencies.shutdown() @@ -108,7 +109,7 @@ app.include_router(session_queue.session_queue_router, prefix="/api") # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? -def custom_openapi(): +def custom_openapi() -> dict[str, Any]: if app.openapi_schema: return app.openapi_schema openapi_schema = get_openapi( @@ -179,18 +180,18 @@ app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid a @app.get("/docs", include_in_schema=False) -def overridden_swagger(): +def overridden_swagger() -> HTMLResponse: return get_swagger_ui_html( - openapi_url=app.openapi_url, + openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string title=app.title, swagger_favicon_url="/static/docs/favicon.ico", ) @app.get("/redoc", include_in_schema=False) -def overridden_redoc(): +def overridden_redoc() -> HTMLResponse: return get_redoc_html( - openapi_url=app.openapi_url, + openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string title=app.title, redoc_favicon_url="/static/docs/favicon.ico", ) @@ -212,8 +213,8 @@ app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales") -def invoke_api(): - def find_port(port: int): +def invoke_api() -> None: + def find_port(port: int) -> int: """Find a port not in use starting at given port""" # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! # https://github.com/WaylonWalker