mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
5 Commits
do-not-mer
...
feat/dynam
Author | SHA1 | Date | |
---|---|---|---|
3985c16183 | |||
751fe68d16 | |||
877348af49 | |||
3dbfee23e6 | |||
17314ea82d |
@ -1,22 +1,24 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Annotated, Optional, Union
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.fields import Field
|
||||
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
NodeAlreadyExecutedError,
|
||||
update_invocations_union,
|
||||
)
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
@ -38,6 +40,24 @@ async def create_session(
|
||||
return session
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/update_nodes",
|
||||
operation_id="update_nodes",
|
||||
)
|
||||
async def update_nodes() -> None:
|
||||
class TestFromRouterOutput(BaseInvocationOutput):
|
||||
type: Literal["test_from_router"] = "test_from_router"
|
||||
|
||||
class TestInvocationFromRouter(BaseInvocation):
|
||||
type: Literal["test_from_router_output"] = "test_from_router_output"
|
||||
|
||||
def invoke(self, context) -> TestFromRouterOutput:
|
||||
return TestFromRouterOutput()
|
||||
|
||||
# doesn't work from here... hmm...
|
||||
update_invocations_union()
|
||||
|
||||
|
||||
@session_router.get(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
|
@ -1,10 +1,13 @@
|
||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@ -15,23 +18,17 @@ from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
import mimetypes
|
||||
|
||||
from invokeai.app.services.graph import update_invocations_union
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||
|
||||
import torch
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, BaseInvocationOutput, UIConfigBase
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
# noinspection PyUnresolvedReferences
|
||||
@ -104,8 +101,8 @@ app.include_router(app_info.app_router, prefix="/api")
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
# if app.openapi_schema:
|
||||
# return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
@ -140,6 +137,9 @@ def custom_openapi():
|
||||
invoker_name = invoker.__name__
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
if invoker_name not in openapi_schema["components"]["schemas"]:
|
||||
openapi_schema["components"]["schemas"][invoker_name] = invoker.schema()
|
||||
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
@ -211,14 +211,14 @@ def invoke_api():
|
||||
|
||||
if app_config.dev_reload:
|
||||
try:
|
||||
import jurigged
|
||||
from invokeai.app.util.dev_reload import start_reloader
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
||||
start_reloader()
|
||||
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
@ -242,6 +242,26 @@ def invoke_api():
|
||||
for ch in logger.handlers:
|
||||
log.addHandler(ch)
|
||||
|
||||
class Test1Output(BaseInvocationOutput):
|
||||
type: Literal["test1_output"] = "test1_output"
|
||||
|
||||
class Test1Invocation(BaseInvocation):
|
||||
type: Literal["test1"] = "test1"
|
||||
|
||||
def invoke(self, context) -> Test1Output:
|
||||
return Test1Output()
|
||||
|
||||
class Test2Output(BaseInvocationOutput):
|
||||
type: Literal["test2_output"] = "test2_output"
|
||||
|
||||
class TestInvocation2(BaseInvocation):
|
||||
type: Literal["test2"] = "test2"
|
||||
|
||||
def invoke(self, context) -> Test2Output:
|
||||
return Test2Output()
|
||||
|
||||
update_invocations_union()
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origi
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
from pydantic.fields import Field, ModelField
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ..invocations import * # noqa: F401 F403
|
||||
@ -232,7 +232,39 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
class DynamicBaseModel(BaseModel):
|
||||
"""https://github.com/pydantic/pydantic/issues/1937#issuecomment-695313040"""
|
||||
|
||||
@classmethod
|
||||
def add_fields(cls, **field_definitions: Any):
|
||||
new_fields: dict[str, ModelField] = {}
|
||||
new_annotations: dict[str, Optional[type]] = {}
|
||||
|
||||
for f_name, f_def in field_definitions.items():
|
||||
if isinstance(f_def, tuple):
|
||||
try:
|
||||
f_annotation, f_value = f_def
|
||||
except ValueError as e:
|
||||
raise Exception(
|
||||
"field definitions should either be a tuple of (<type>, <default>) or just a "
|
||||
"default value, unfortunately this means tuples as "
|
||||
"default values are not allowed"
|
||||
) from e
|
||||
else:
|
||||
f_annotation, f_value = None, f_def
|
||||
|
||||
if f_annotation:
|
||||
new_annotations[f_name] = f_annotation
|
||||
|
||||
new_fields[f_name] = ModelField.infer(
|
||||
name=f_name, value=f_value, annotation=f_annotation, class_validators=None, config=cls.__config__
|
||||
)
|
||||
|
||||
cls.__fields__.update(new_fields)
|
||||
cls.__annotations__.update(new_annotations)
|
||||
|
||||
|
||||
class Graph(DynamicBaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
@ -700,7 +732,7 @@ class Graph(BaseModel):
|
||||
return g
|
||||
|
||||
|
||||
class GraphExecutionState(BaseModel):
|
||||
class GraphExecutionState(DynamicBaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
@ -1131,3 +1163,24 @@ class LibraryGraph(BaseModel):
|
||||
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
||||
|
||||
|
||||
def update_invocations_union() -> None:
|
||||
global InvocationsUnion
|
||||
global InvocationOutputsUnion
|
||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||
|
||||
Graph.add_fields(
|
||||
nodes=(
|
||||
dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]],
|
||||
Field(description="The nodes in this graph", default_factory=dict),
|
||||
)
|
||||
)
|
||||
|
||||
GraphExecutionState.add_fields(
|
||||
results=(
|
||||
dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]],
|
||||
Field(description="The results of node executions", default_factory=dict),
|
||||
)
|
||||
)
|
||||
|
31
invokeai/app/util/dev_reload.py
Normal file
31
invokeai/app/util/dev_reload.py
Normal file
@ -0,0 +1,31 @@
|
||||
import jurigged
|
||||
from jurigged.codetools import ClassDefinition
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.getLogger(name=__name__)
|
||||
|
||||
|
||||
def reload_nodes(path: str, codefile: jurigged.CodeFile):
|
||||
"""Callback function for jurigged post-run events."""
|
||||
# Things we have access to here:
|
||||
# codefile.module:module - the module object associated with this file
|
||||
# codefile.module_name:str - the full module name (its key in sys.modules)
|
||||
# codefile.root:ModuleCode - an AST of the current source
|
||||
|
||||
# This is only reading top-level statements, not walking the whole AST, but class definition should be top-level, right?
|
||||
class_names = [statement.name for statement in codefile.root.children if isinstance(statement, ClassDefinition)]
|
||||
classes = [getattr(codefile.module, name) for name in class_names]
|
||||
invocations = [cls for cls in classes if issubclass(cls, BaseInvocation)]
|
||||
# outputs = [cls for cls in classes if issubclass(cls, BaseInvocationOutput)]
|
||||
|
||||
# We should assume jurigged has already replaced all references to methods of these classes,
|
||||
# but it hasn't re-executed any annotations on them (like @title or @tags).
|
||||
# We need to re-do anything that involved introspection like BaseInvocation.get_all_subclasses()
|
||||
logger.info("File reloaded: %s contains invocation classes %s", path, invocations)
|
||||
|
||||
|
||||
def start_reloader():
|
||||
watcher = jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
||||
watcher.postrun.register(reload_nodes, apply_history=False)
|
Reference in New Issue
Block a user