feat(api): chore: pydantic & fastapi upgrade

Upgrade pydantic and fastapi to latest.

- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1

**Big Changes**

There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.

**Invocations**

The biggest change relates to invocation creation, instantiation and validation.

Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.

Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.

With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.

This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.

In the end, this implementation is cleaner.

**Invocation Fields**

In pydantic v2, you can no longer directly add or remove fields from a model.

Previously, we did this to add the `type` field to invocations.

**Invocation Decorators**

With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.

A similar technique is used for `invocation_output()`.

**Minor Changes**

There are a number of minor changes around the pydantic v2 models API.

**Protected `model_` Namespace**

All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".

Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.

```py
class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")

    model_config = ConfigDict(protected_namespaces=())
```

**Model Serialization**

Pydantic models no longer have `Model.dict()` or `Model.json()`.

Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.

**Model Deserialization**

Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.

Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.

```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```

**Field Customisation**

Pydantic `Field`s no longer accept arbitrary args.

Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.

**Schema Customisation**

FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.

This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised

The specific aren't important, but this does present additional surface area for bugs.

**Performance Improvements**

Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.

I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
This commit is contained in:
psychedelicious
2023-09-24 18:11:07 +10:00
parent 19c5435332
commit c238a7f18b
74 changed files with 2788 additions and 3116 deletions

View File

@ -22,7 +22,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from pydantic.json_schema import models_json_schema
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
@ -31,7 +31,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
from .api.routers import app_info, board_images, boards, images, models, session_queue, utilities
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
@ -51,7 +51,7 @@ mimetypes.add_type("text/css", ".css")
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
# Add event handler
event_handler_id: int = id(app)
@ -63,18 +63,18 @@ app.add_middleware(
socket_io = SocketIO(app)
app.add_middleware(
CORSMiddleware,
allow_origins=app_config.allow_origins,
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
)
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
app.add_middleware(
CORSMiddleware,
allow_origins=app_config.allow_origins,
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
@ -85,12 +85,7 @@ async def shutdown_event():
# Include all routers
# TODO: REMOVE
# app.include_router(
# invocation.invocation_router,
# prefix = '/api')
app.include_router(sessions.session_router, prefix="/api")
# app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
@ -117,6 +112,7 @@ def custom_openapi():
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
@ -127,29 +123,32 @@ def custom_openapi():
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
for schema_key, output_schema in output_schemas["definitions"].items():
output_schema["class"] = "output"
openapi_schema["components"]["schemas"][schema_key] = output_schema
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
# Add Node Editor UI helper schemas
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
ui_config_schemas = models_json_schema(
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
ref_template="#/components/schemas/{model}",
)
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__
output_type = signature(invoker.invoke).return_annotation
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
from invokeai.backend.model_management.models import get_model_config_enums
@ -172,7 +171,7 @@ def custom_openapi():
return app.openapi_schema
app.openapi = custom_openapi
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
# Override API doc favicons
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")