mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Do less stuff on import of api_app.py. Instead, call functions imperatively.
This commit is contained in:
parent
caa7c0f2bd
commit
0bdf7f5726
@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
@ -20,13 +21,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -44,188 +42,189 @@ from .api.routers import (
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
# TODO(ryand): Search for imports from api_app.py in the rest of the codebase and make sure I didn't break any of them.
|
||||
def build_app(app_config: InvokeAIAppConfig, logger: logging.Logger) -> FastAPI:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
app = FastAPI(
|
||||
title="Invoke - Community Edition",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
separate_input_output_schemas=False,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# Include all routers
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
|
||||
add_custom_openapi(app)
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger() -> HTMLResponse:
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||
title=f"{app.title} - Swagger UI",
|
||||
swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
|
||||
)
|
||||
|
||||
@app.get("/redoc", include_in_schema=False)
|
||||
def overridden_redoc() -> HTMLResponse:
|
||||
return get_redoc_html(
|
||||
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||
title=f"{app.title} - Redoc",
|
||||
redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
|
||||
)
|
||||
|
||||
web_root_path = Path(list(web_dir.__path__)[0])
|
||||
|
||||
try:
|
||||
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
|
||||
except RuntimeError:
|
||||
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
||||
|
||||
app.mount(
|
||||
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
||||
) # docs favicon is in here
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
def apply_monkeypatches() -> None:
|
||||
# TODO(ryand): Don't monkeypatch on import!
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
def fix_mimetypes():
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
def add_custom_openapi(app: FastAPI) -> None:
|
||||
"""Add a custom .openapi() method to the FastAPI app.
|
||||
|
||||
This is done based on the guidance here:
|
||||
https://fastapi.tiangolo.com/how-to/extending-openapi/#normal-fastapi
|
||||
"""
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(
|
||||
title="Invoke - Community Edition",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
separate_input_output_schemas=False,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
# Add all outputs
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
output_types = set()
|
||||
output_type_titles = {}
|
||||
for invoker in all_invocations:
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
output_schemas = models_json_schema(
|
||||
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
# Some models don't end up in the schemas as standalone definitions
|
||||
additional_schemas = models_json_schema(
|
||||
[
|
||||
(UIConfigBase, "serialization"),
|
||||
(InputFieldJSONSchemaExtra, "serialization"),
|
||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||
(ModelIdentifierField, "serialization"),
|
||||
(ProgressImage, "serialization"),
|
||||
],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
output_type = signature(obj=invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
# Include all routers
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# Add all outputs
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
output_types = set()
|
||||
output_type_titles = {}
|
||||
for invoker in all_invocations:
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
|
||||
output_schemas = models_json_schema(
|
||||
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions
|
||||
additional_schemas = models_json_schema(
|
||||
[
|
||||
(UIConfigBase, "serialization"),
|
||||
(InputFieldJSONSchemaExtra, "serialization"),
|
||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||
(ModelIdentifierField, "serialization"),
|
||||
(ProgressImage, "serialization"),
|
||||
],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
output_type = signature(obj=invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger() -> HTMLResponse:
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||
title=f"{app.title} - Swagger UI",
|
||||
swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/redoc", include_in_schema=False)
|
||||
def overridden_redoc() -> HTMLResponse:
|
||||
return get_redoc_html(
|
||||
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
|
||||
title=f"{app.title} - Redoc",
|
||||
redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
|
||||
)
|
||||
|
||||
|
||||
web_root_path = Path(list(web_dir.__path__)[0])
|
||||
|
||||
try:
|
||||
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
|
||||
except RuntimeError:
|
||||
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
||||
app.mount(
|
||||
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
||||
) # docs favicon is in here
|
||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||
|
||||
|
||||
def check_cudnn(logger: logging.Logger) -> None:
|
||||
@ -244,27 +243,39 @@ def check_cudnn(logger: logging.Logger) -> None:
|
||||
)
|
||||
|
||||
|
||||
def find_port(port: int) -> int:
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
# https://github.com/WaylonWalker
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(("localhost", port)) == 0:
|
||||
return find_port(port=port + 1)
|
||||
else:
|
||||
return port
|
||||
|
||||
|
||||
def init_dev_reload(logger: logging.Logger) -> None:
|
||||
try:
|
||||
import jurigged
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
|
||||
def invoke_api() -> None:
|
||||
def find_port(port: int) -> int:
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
# https://github.com/WaylonWalker
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(("localhost", port)) == 0:
|
||||
return find_port(port=port + 1)
|
||||
else:
|
||||
return port
|
||||
start = time.time()
|
||||
app_config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
|
||||
apply_monkeypatches()
|
||||
fix_mimetypes()
|
||||
|
||||
if app_config.dev_reload:
|
||||
try:
|
||||
import jurigged
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
init_dev_reload(logger)
|
||||
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
@ -272,6 +283,11 @@ def invoke_api() -> None:
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
app = build_app(app_config, logger)
|
||||
|
||||
# Start our own event loop for eventing usage
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(
|
||||
@ -291,6 +307,7 @@ def invoke_api() -> None:
|
||||
log.handlers.clear()
|
||||
for ch in logger.handlers:
|
||||
log.addHandler(ch)
|
||||
logger.info(f"API started in {time.time() - start:.2f} seconds")
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
@ -7,10 +7,12 @@ import os
|
||||
|
||||
from invokeai.app.run_app import run_app
|
||||
|
||||
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
|
||||
|
||||
|
||||
def main():
|
||||
logging.getLogger("xformers").addFilter(
|
||||
lambda record: "A matching Triton is not available" not in record.getMessage()
|
||||
)
|
||||
|
||||
# Change working directory to the repo root
|
||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
run_app()
|
||||
|
Loading…
Reference in New Issue
Block a user