Compare commits

...

5 Commits

4 changed files with 147 additions and 23 deletions

View File

@ -1,22 +1,24 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, Optional, Union
from typing import Annotated, Literal, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic.fields import Field
from invokeai.app.services.item_storage import PaginatedResults
# Importing * is bad karma but needed here for node detection
from ...invocations import * # noqa: F401 F403
from ...invocations.baseinvocation import BaseInvocation
from ...invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
NodeAlreadyExecutedError,
update_invocations_union,
)
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
@ -38,6 +40,24 @@ async def create_session(
return session
@session_router.post(
"/update_nodes",
operation_id="update_nodes",
)
async def update_nodes() -> None:
class TestFromRouterOutput(BaseInvocationOutput):
type: Literal["test_from_router"] = "test_from_router"
class TestInvocationFromRouter(BaseInvocation):
type: Literal["test_from_router_output"] = "test_from_router_output"
def invoke(self, context) -> TestFromRouterOutput:
return TestFromRouterOutput()
# doesn't work from here... hmm...
update_invocations_union()
@session_router.get(
"/",
operation_id="list_sessions",

View File

@ -1,10 +1,13 @@
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import asyncio
import logging
import mimetypes
import socket
from inspect import signature
from pathlib import Path
from typing import Literal
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -15,23 +18,17 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
import mimetypes
from invokeai.app.services.graph import update_invocations_union
from invokeai.version.invokeai_version import __version__
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
import torch
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, BaseInvocationOutput, UIConfigBase
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
if torch.backends.mps.is_available():
# noinspection PyUnresolvedReferences
@ -104,8 +101,8 @@ app.include_router(app_info.app_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
# if app.openapi_schema:
# return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
@ -140,6 +137,9 @@ def custom_openapi():
invoker_name = invoker.__name__
output_type = signature(invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
if invoker_name not in openapi_schema["components"]["schemas"]:
openapi_schema["components"]["schemas"][invoker_name] = invoker.schema()
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
@ -211,14 +211,14 @@ def invoke_api():
if app_config.dev_reload:
try:
import jurigged
from invokeai.app.util.dev_reload import start_reloader
except ImportError as e:
logger.error(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
start_reloader()
port = find_port(app_config.port)
if port != app_config.port:
@ -242,6 +242,26 @@ def invoke_api():
for ch in logger.handlers:
log.addHandler(ch)
class Test1Output(BaseInvocationOutput):
type: Literal["test1_output"] = "test1_output"
class Test1Invocation(BaseInvocation):
type: Literal["test1"] = "test1"
def invoke(self, context) -> Test1Output:
return Test1Output()
class Test2Output(BaseInvocationOutput):
type: Literal["test2_output"] = "test2_output"
class TestInvocation2(BaseInvocation):
type: Literal["test2"] = "test2"
def invoke(self, context) -> Test2Output:
return Test2Output()
update_invocations_union()
loop.run_until_complete(server.serve())

View File

@ -7,7 +7,7 @@ from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origi
import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from pydantic.fields import Field, ModelField
# Importing * is bad karma but needed here for node detection
from ..invocations import * # noqa: F401 F403
@ -232,7 +232,39 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
class Graph(BaseModel):
class DynamicBaseModel(BaseModel):
"""https://github.com/pydantic/pydantic/issues/1937#issuecomment-695313040"""
@classmethod
def add_fields(cls, **field_definitions: Any):
new_fields: dict[str, ModelField] = {}
new_annotations: dict[str, Optional[type]] = {}
for f_name, f_def in field_definitions.items():
if isinstance(f_def, tuple):
try:
f_annotation, f_value = f_def
except ValueError as e:
raise Exception(
"field definitions should either be a tuple of (<type>, <default>) or just a "
"default value, unfortunately this means tuples as "
"default values are not allowed"
) from e
else:
f_annotation, f_value = None, f_def
if f_annotation:
new_annotations[f_name] = f_annotation
new_fields[f_name] = ModelField.infer(
name=f_name, value=f_value, annotation=f_annotation, class_validators=None, config=cls.__config__
)
cls.__fields__.update(new_fields)
cls.__annotations__.update(new_annotations)
class Graph(DynamicBaseModel):
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
@ -700,7 +732,7 @@ class Graph(BaseModel):
return g
class GraphExecutionState(BaseModel):
class GraphExecutionState(DynamicBaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
@ -1131,3 +1163,24 @@ class LibraryGraph(BaseModel):
GraphInvocation.update_forward_refs()
def update_invocations_union() -> None:
global InvocationsUnion
global InvocationOutputsUnion
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
Graph.add_fields(
nodes=(
dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]],
Field(description="The nodes in this graph", default_factory=dict),
)
)
GraphExecutionState.add_fields(
results=(
dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]],
Field(description="The results of node executions", default_factory=dict),
)
)

View File

@ -0,0 +1,31 @@
import jurigged
from jurigged.codetools import ClassDefinition
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(name=__name__)
def reload_nodes(path: str, codefile: jurigged.CodeFile):
"""Callback function for jurigged post-run events."""
# Things we have access to here:
# codefile.module:module - the module object associated with this file
# codefile.module_name:str - the full module name (its key in sys.modules)
# codefile.root:ModuleCode - an AST of the current source
# This is only reading top-level statements, not walking the whole AST, but class definition should be top-level, right?
class_names = [statement.name for statement in codefile.root.children if isinstance(statement, ClassDefinition)]
classes = [getattr(codefile.module, name) for name in class_names]
invocations = [cls for cls in classes if issubclass(cls, BaseInvocation)]
# outputs = [cls for cls in classes if issubclass(cls, BaseInvocationOutput)]
# We should assume jurigged has already replaced all references to methods of these classes,
# but it hasn't re-executed any annotations on them (like @title or @tags).
# We need to re-do anything that involved introspection like BaseInvocation.get_all_subclasses()
logger.info("File reloaded: %s contains invocation classes %s", path, invocations)
def start_reloader():
watcher = jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
watcher.postrun.register(reload_nodes, apply_history=False)