mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
2f9ebdec69
Some tech debt related to dynamic pydantic schemas for invocations became problematic. Including the invocations and results in the event schemas was breaking pydantic's handling of ref schemas. I don't really understand why — I think it's a pydantic bug in a remote edge case that we are hitting. After many failed attempts I landed on this implementation, which is actually much tidier than what was in there before. - Create pydantic-enabled types for `AnyInvocation` and `AnyInvocationOutput` and use these in place of the janky dynamic unions. They are essentially the same unions, but better encapsulated. Use these in `Graph`, `GraphExecutionState`, `InvocationEventBase` and `InvocationCompleteEvent`. - Revise the custom openapi function to work with the new models. - Split out the custom openapi function to a separate file. Add a `post_transform` callback so consumers can customize the output schema. - Update makefile scripts.
212 lines
7.1 KiB
Python
212 lines
7.1 KiB
Python
import asyncio
|
|
import logging
|
|
import mimetypes
|
|
import socket
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import uvicorn
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.gzip import GZipMiddleware
|
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
|
from fastapi.responses import HTMLResponse
|
|
from fastapi_events.handlers.local import local_handler
|
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|
from torch.backends.mps import is_available as is_mps_available
|
|
|
|
# for PyCharm:
|
|
# noinspection PyUnresolvedReferences
|
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
|
import invokeai.frontend.web as web_dir
|
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
|
from invokeai.app.services.config.config_default import get_config
|
|
from invokeai.app.util.custom_openapi import get_openapi_func
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
|
|
from ..backend.util.logging import InvokeAILogger
|
|
from .api.dependencies import ApiDependencies
|
|
from .api.routers import (
|
|
app_info,
|
|
board_images,
|
|
boards,
|
|
download_queue,
|
|
images,
|
|
model_manager,
|
|
session_queue,
|
|
utilities,
|
|
workflows,
|
|
)
|
|
from .api.sockets import SocketIO
|
|
|
|
# Application configuration, read at import time; used below for logging, CORS and server settings.
app_config = get_config()


# Apply MPS-specific monkeypatches only when torch reports the MPS backend as available.
if is_mps_available():
    import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)


# Module-wide logger configured from app_config.
logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

# Resolve the torch device name at import time and report it in the log.
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: initialize API dependencies on startup and shut them down on exit.

    ``event_handler_id`` is a module-level global assigned after ``app`` is created. It is looked
    up when this lifespan runs (at server startup), not when the function is defined.
    """
    # Add startup event to load dependencies
    ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
    try:
        yield
    finally:
        # Shut down threads. The finally ensures this also runs when an exception or cancellation
        # is thrown into the lifespan context at the yield; previously shutdown was skipped then.
        ApiDependencies.shutdown()
|
|
|
|
|
|
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
    title="Invoke - Community Edition",
    lifespan=lifespan,
    separate_input_output_schemas=False,
    # The built-in docs pages are disabled; custom /docs and /redoc routes are registered below.
    docs_url=None,
    redoc_url=None,
)

# Add event handler. The same id is handed to ApiDependencies.initialize() in lifespan().
event_handler_id: int = id(app)
app.add_middleware(
    EventHandlerASGIMiddleware,
    middleware_id=event_handler_id,
    handlers=[local_handler],  # TODO: consider doing this in services to support different configurations
)

socket_io = SocketIO(app)

# Middleware registration order is significant; keep CORS before GZip as originally written.
app.add_middleware(
    CORSMiddleware,
    allow_headers=app_config.allow_headers,
    allow_methods=app_config.allow_methods,
    allow_credentials=app_config.allow_credentials,
    allow_origins=app_config.allow_origins,
)
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Include all routers under the /api prefix, in this exact order.
for _router in (
    utilities.utilities_router,
    model_manager.model_manager_router,
    download_queue.download_queue_router,
    images.images_router,
    boards.boards_router,
    board_images.board_images_router,
    app_info.app_router,
    session_queue.session_queue_router,
    workflows.workflows_router,
):
    app.include_router(_router, prefix="/api")
# Don't leave the loop variable behind in the module namespace.
del _router

app.openapi = get_openapi_func(app)
|
|
|
|
|
|
@app.get("/docs", include_in_schema=False)
def overridden_swagger() -> HTMLResponse:
    """Render the Swagger UI page for this app's OpenAPI schema, using the Invoke docs favicon."""
    page_title = f"{app.title} - Swagger UI"
    return get_swagger_ui_html(
        title=page_title,
        swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
        openapi_url=app.openapi_url,  # type: ignore [arg-type] # this is always a string
    )
|
|
|
|
|
|
@app.get("/redoc", include_in_schema=False)
def overridden_redoc() -> HTMLResponse:
    """Render the ReDoc page for this app's OpenAPI schema, using the Invoke docs favicon."""
    page_title = f"{app.title} - Redoc"
    return get_redoc_html(
        title=page_title,
        redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
        openapi_url=app.openapi_url,  # type: ignore [arg-type] # this is always a string
    )
|
|
|
|
|
|
# Filesystem location of the invokeai.frontend.web package (first entry of its __path__).
web_root_path = Path(list(web_dir.__path__)[0])

# Serve the built UI from <web_root>/dist at "/". A missing build surfaces as RuntimeError
# (that is what the handler below expects); the API keeps running without the UI.
try:
    app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
    # Logger.warn is a deprecated alias; use Logger.warning.
    logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
# NOTE(review): this "/static" mount is registered after the "/" mount above. If routes are
# matched in registration order and "/" matches every path, requests under /static may be served
# from dist/ instead of this directory whenever the UI is mounted. Confirm the docs favicon resolves.
app.mount(
    "/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
)  # docs favicon is in here
|
|
|
|
|
|
def check_cudnn(logger: logging.Logger) -> None:
    """Check for cuDNN issues that could be causing degraded performance.

    If cuDNN is unavailable this is a no-op. Otherwise it logs the cuDNN version at INFO level,
    or logs a WARNING with the error text if querying the version raises ``RuntimeError``.
    """
    if not torch.backends.cudnn.is_available():
        return

    try:
        # Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
        # time it is called. Subsequent calls will return the version number without complaining about a mismatch.
        detected_version = torch.backends.cudnn.version()
        logger.info(f"cuDNN version: {detected_version}")
    except RuntimeError as err:
        logger.warning(
            "Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
            "caused by an incompatible cuDNN version installed in your python environment, or on the host "
            f"system. Full error message:\n{err}"
        )
|
|
|
|
|
|
def invoke_api() -> None:
    """Start the Invoke API server and block until it stops.

    Steps, in order:
    - optionally enable jurigged live reloading (``app_config.dev_reload``);
    - choose the first free port at or above ``app_config.port``;
    - run the cuDNN sanity check;
    - route uvicorn's loggers through InvokeAI's handlers;
    - serve ``app`` with uvicorn on a dedicated asyncio event loop, which is closed afterwards.
    """

    def find_port(port: int) -> int:
        """Find a port not in use starting at given port"""
        # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
        # https://github.com/WaylonWalker
        # connect_ex() returns 0 when the connection succeeds, i.e. something is already listening.
        # Iterate instead of recursing so a long run of busy ports cannot exhaust the recursion limit.
        # NOTE(review): this probes "localhost" only, while the server binds to app_config.host.
        while True:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                if s.connect_ex(("localhost", port)) != 0:
                    return port
            port += 1

    if app_config.dev_reload:
        try:
            import jurigged
        except ImportError as e:
            logger.error(
                'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
                exc_info=e,
            )
        else:
            jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)

    port = find_port(app_config.port)
    if port != app_config.port:
        # Logger.warn is a deprecated alias; use Logger.warning.
        logger.warning(f"Port {app_config.port} in use, using port {port}")

    check_cudnn(logger)

    # Start our own event loop for eventing usage
    loop = asyncio.new_event_loop()
    config = uvicorn.Config(
        app=app,
        host=app_config.host,
        port=port,
        loop="asyncio",
        log_level=app_config.log_level,
        ssl_certfile=app_config.ssl_certfile,
        ssl_keyfile=app_config.ssl_keyfile,
    )
    server = uvicorn.Server(config)

    # replace uvicorn's loggers with InvokeAI's for consistent appearance
    for logname in ["uvicorn.access", "uvicorn"]:
        log = InvokeAILogger.get_logger(logname)
        log.handlers.clear()
        for ch in logger.handlers:
            log.addHandler(ch)

    try:
        loop.run_until_complete(server.serve())
    finally:
        # Close the loop we created so its selector/file descriptors are released and no
        # "unclosed event loop" ResourceWarning is emitted.
        loop.close()
|
|
|
|
|
|
# Allow starting the API server by running this module directly.
if __name__ == "__main__":
    invoke_api()
|