mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Do less stuff on import of api_app.py. Instead, call functions imperatively.
This commit is contained in:
parent
caa7c0f2bd
commit
0bdf7f5726
@ -2,6 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import socket
|
import socket
|
||||||
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -20,13 +21,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|||||||
from pydantic.json_schema import models_json_schema
|
from pydantic.json_schema import models_json_schema
|
||||||
from torch.backends.mps import is_available as is_mps_available
|
from torch.backends.mps import is_available as is_mps_available
|
||||||
|
|
||||||
# for PyCharm:
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
@ -44,29 +42,12 @@ from .api.routers import (
|
|||||||
workflows,
|
workflows,
|
||||||
)
|
)
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import (
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase
|
||||||
BaseInvocation,
|
|
||||||
UIConfigBase,
|
|
||||||
)
|
|
||||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||||
|
|
||||||
app_config = get_config()
|
|
||||||
|
|
||||||
|
|
||||||
if is_mps_available():
|
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
|
||||||
|
|
||||||
|
|
||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
|
||||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
|
||||||
mimetypes.add_type("application/javascript", ".js")
|
|
||||||
mimetypes.add_type("text/css", ".css")
|
|
||||||
|
|
||||||
torch_device_name = TorchDevice.get_torch_device_name()
|
|
||||||
logger.info(f"Using torch device: {torch_device_name}")
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(ryand): Search for imports from api_app.py in the rest of the codebase and make sure I didn't break any of them.
|
||||||
|
def build_app(app_config: InvokeAIAppConfig, logger: logging.Logger) -> FastAPI:
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@ -75,9 +56,6 @@ async def lifespan(app: FastAPI):
|
|||||||
# Shut down threads
|
# Shut down threads
|
||||||
ApiDependencies.shutdown()
|
ApiDependencies.shutdown()
|
||||||
|
|
||||||
|
|
||||||
# Create the app
|
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Invoke - Community Edition",
|
title="Invoke - Community Edition",
|
||||||
docs_url=None,
|
docs_url=None,
|
||||||
@ -106,7 +84,6 @@ app.add_middleware(
|
|||||||
|
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
|
||||||
|
|
||||||
# Include all routers
|
# Include all routers
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||||
@ -118,6 +95,58 @@ app.include_router(app_info.app_router, prefix="/api")
|
|||||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||||
app.include_router(workflows.workflows_router, prefix="/api")
|
app.include_router(workflows.workflows_router, prefix="/api")
|
||||||
|
|
||||||
|
add_custom_openapi(app)
|
||||||
|
|
||||||
|
@app.get("/docs", include_in_schema=False)
|
||||||
|
def overridden_swagger() -> HTMLResponse:
|
||||||
|
return get_swagger_ui_html(
|
||||||
|
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||||
|
title=f"{app.title} - Swagger UI",
|
||||||
|
swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/redoc", include_in_schema=False)
|
||||||
|
def overridden_redoc() -> HTMLResponse:
|
||||||
|
return get_redoc_html(
|
||||||
|
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||||
|
title=f"{app.title} - Redoc",
|
||||||
|
redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
|
||||||
|
)
|
||||||
|
|
||||||
|
web_root_path = Path(list(web_dir.__path__)[0])
|
||||||
|
|
||||||
|
try:
|
||||||
|
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
||||||
|
|
||||||
|
app.mount(
|
||||||
|
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
||||||
|
) # docs favicon is in here
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def apply_monkeypatches() -> None:
|
||||||
|
# TODO(ryand): Don't monkeypatch on import!
|
||||||
|
if is_mps_available():
|
||||||
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_mimetypes():
|
||||||
|
# fix for windows mimetypes registry entries being borked
|
||||||
|
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||||
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
|
||||||
|
|
||||||
|
def add_custom_openapi(app: FastAPI) -> None:
|
||||||
|
"""Add a custom .openapi() method to the FastAPI app.
|
||||||
|
|
||||||
|
This is done based on the guidance here:
|
||||||
|
https://fastapi.tiangolo.com/how-to/extending-openapi/#normal-fastapi
|
||||||
|
"""
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
@ -195,39 +224,9 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
|
|
||||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||||
|
|
||||||
|
|
||||||
@app.get("/docs", include_in_schema=False)
|
|
||||||
def overridden_swagger() -> HTMLResponse:
|
|
||||||
return get_swagger_ui_html(
|
|
||||||
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
|
||||||
title=f"{app.title} - Swagger UI",
|
|
||||||
swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/redoc", include_in_schema=False)
|
|
||||||
def overridden_redoc() -> HTMLResponse:
|
|
||||||
return get_redoc_html(
|
|
||||||
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
|
||||||
title=f"{app.title} - Redoc",
|
|
||||||
redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
web_root_path = Path(list(web_dir.__path__)[0])
|
|
||||||
|
|
||||||
try:
|
|
||||||
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
|
|
||||||
except RuntimeError:
|
|
||||||
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
|
||||||
app.mount(
|
|
||||||
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
|
||||||
) # docs favicon is in here
|
|
||||||
|
|
||||||
|
|
||||||
def check_cudnn(logger: logging.Logger) -> None:
|
def check_cudnn(logger: logging.Logger) -> None:
|
||||||
"""Check for cuDNN issues that could be causing degraded performance."""
|
"""Check for cuDNN issues that could be causing degraded performance."""
|
||||||
if torch.backends.cudnn.is_available():
|
if torch.backends.cudnn.is_available():
|
||||||
@ -244,7 +243,6 @@ def check_cudnn(logger: logging.Logger) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def invoke_api() -> None:
|
|
||||||
def find_port(port: int) -> int:
|
def find_port(port: int) -> int:
|
||||||
"""Find a port not in use starting at given port"""
|
"""Find a port not in use starting at given port"""
|
||||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||||
@ -255,7 +253,8 @@ def invoke_api() -> None:
|
|||||||
else:
|
else:
|
||||||
return port
|
return port
|
||||||
|
|
||||||
if app_config.dev_reload:
|
|
||||||
|
def init_dev_reload(logger: logging.Logger) -> None:
|
||||||
try:
|
try:
|
||||||
import jurigged
|
import jurigged
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@ -266,12 +265,29 @@ def invoke_api() -> None:
|
|||||||
else:
|
else:
|
||||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_api() -> None:
|
||||||
|
start = time.time()
|
||||||
|
app_config = get_config()
|
||||||
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
|
|
||||||
|
apply_monkeypatches()
|
||||||
|
fix_mimetypes()
|
||||||
|
|
||||||
|
if app_config.dev_reload:
|
||||||
|
init_dev_reload(logger)
|
||||||
|
|
||||||
port = find_port(app_config.port)
|
port = find_port(app_config.port)
|
||||||
if port != app_config.port:
|
if port != app_config.port:
|
||||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||||
|
|
||||||
check_cudnn(logger)
|
check_cudnn(logger)
|
||||||
|
|
||||||
|
torch_device_name = TorchDevice.get_torch_device_name()
|
||||||
|
logger.info(f"Using torch device: {torch_device_name}")
|
||||||
|
|
||||||
|
app = build_app(app_config, logger)
|
||||||
|
|
||||||
# Start our own event loop for eventing usage
|
# Start our own event loop for eventing usage
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(
|
||||||
@ -291,6 +307,7 @@ def invoke_api() -> None:
|
|||||||
log.handlers.clear()
|
log.handlers.clear()
|
||||||
for ch in logger.handlers:
|
for ch in logger.handlers:
|
||||||
log.addHandler(ch)
|
log.addHandler(ch)
|
||||||
|
logger.info(f"API started in {time.time() - start:.2f} seconds")
|
||||||
|
|
||||||
loop.run_until_complete(server.serve())
|
loop.run_until_complete(server.serve())
|
||||||
|
|
||||||
|
@ -7,10 +7,12 @@ import os
|
|||||||
|
|
||||||
from invokeai.app.run_app import run_app
|
from invokeai.app.run_app import run_app
|
||||||
|
|
||||||
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
logging.getLogger("xformers").addFilter(
|
||||||
|
lambda record: "A matching Triton is not available" not in record.getMessage()
|
||||||
|
)
|
||||||
|
|
||||||
# Change working directory to the repo root
|
# Change working directory to the repo root
|
||||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
run_app()
|
run_app()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user