mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
5 Commits
feat/workf
...
feat/dynam
Author | SHA1 | Date | |
---|---|---|---|
3985c16183 | |||
751fe68d16 | |||
877348af49 | |||
3dbfee23e6 | |||
17314ea82d |
@ -1,22 +1,24 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Annotated, Optional, Union
|
from typing import Annotated, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, HTTPException, Path, Query, Response
|
from fastapi import Body, HTTPException, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from ...invocations import * # noqa: F401 F403
|
from ...invocations import * # noqa: F401 F403
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
from ...services.graph import (
|
from ...services.graph import (
|
||||||
Edge,
|
Edge,
|
||||||
EdgeConnection,
|
EdgeConnection,
|
||||||
Graph,
|
Graph,
|
||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
NodeAlreadyExecutedError,
|
NodeAlreadyExecutedError,
|
||||||
|
update_invocations_union,
|
||||||
)
|
)
|
||||||
from ...services.item_storage import PaginatedResults
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||||
@ -38,6 +40,24 @@ async def create_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.post(
|
||||||
|
"/update_nodes",
|
||||||
|
operation_id="update_nodes",
|
||||||
|
)
|
||||||
|
async def update_nodes() -> None:
|
||||||
|
class TestFromRouterOutput(BaseInvocationOutput):
|
||||||
|
type: Literal["test_from_router"] = "test_from_router"
|
||||||
|
|
||||||
|
class TestInvocationFromRouter(BaseInvocation):
|
||||||
|
type: Literal["test_from_router_output"] = "test_from_router_output"
|
||||||
|
|
||||||
|
def invoke(self, context) -> TestFromRouterOutput:
|
||||||
|
return TestFromRouterOutput()
|
||||||
|
|
||||||
|
# doesn't work from here... hmm...
|
||||||
|
update_invocations_union()
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
@session_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_sessions",
|
operation_id="list_sessions",
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import mimetypes
|
||||||
import socket
|
import socket
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@ -15,23 +18,17 @@ from fastapi_events.handlers.local import local_handler
|
|||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
from pydantic.schema import schema
|
from pydantic.schema import schema
|
||||||
|
|
||||||
from .services.config import InvokeAIAppConfig
|
# noinspection PyUnresolvedReferences
|
||||||
from ..backend.util.logging import InvokeAILogger
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
|
||||||
|
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
import mimetypes
|
from invokeai.app.services.graph import update_invocations_union
|
||||||
|
from invokeai.version.invokeai_version import __version__
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, BaseInvocationOutput, UIConfigBase
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
import torch
|
from ..backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
@ -104,8 +101,8 @@ app.include_router(app_info.app_router, prefix="/api")
|
|||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
def custom_openapi():
|
def custom_openapi():
|
||||||
if app.openapi_schema:
|
# if app.openapi_schema:
|
||||||
return app.openapi_schema
|
# return app.openapi_schema
|
||||||
openapi_schema = get_openapi(
|
openapi_schema = get_openapi(
|
||||||
title=app.title,
|
title=app.title,
|
||||||
description="An API for invoking AI image operations",
|
description="An API for invoking AI image operations",
|
||||||
@ -140,6 +137,9 @@ def custom_openapi():
|
|||||||
invoker_name = invoker.__name__
|
invoker_name = invoker.__name__
|
||||||
output_type = signature(invoker.invoke).return_annotation
|
output_type = signature(invoker.invoke).return_annotation
|
||||||
output_type_title = output_type_titles[output_type.__name__]
|
output_type_title = output_type_titles[output_type.__name__]
|
||||||
|
if invoker_name not in openapi_schema["components"]["schemas"]:
|
||||||
|
openapi_schema["components"]["schemas"][invoker_name] = invoker.schema()
|
||||||
|
|
||||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
@ -211,14 +211,14 @@ def invoke_api():
|
|||||||
|
|
||||||
if app_config.dev_reload:
|
if app_config.dev_reload:
|
||||||
try:
|
try:
|
||||||
import jurigged
|
from invokeai.app.util.dev_reload import start_reloader
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||||
exc_info=e,
|
exc_info=e,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
start_reloader()
|
||||||
|
|
||||||
port = find_port(app_config.port)
|
port = find_port(app_config.port)
|
||||||
if port != app_config.port:
|
if port != app_config.port:
|
||||||
@ -242,6 +242,26 @@ def invoke_api():
|
|||||||
for ch in logger.handlers:
|
for ch in logger.handlers:
|
||||||
log.addHandler(ch)
|
log.addHandler(ch)
|
||||||
|
|
||||||
|
class Test1Output(BaseInvocationOutput):
|
||||||
|
type: Literal["test1_output"] = "test1_output"
|
||||||
|
|
||||||
|
class Test1Invocation(BaseInvocation):
|
||||||
|
type: Literal["test1"] = "test1"
|
||||||
|
|
||||||
|
def invoke(self, context) -> Test1Output:
|
||||||
|
return Test1Output()
|
||||||
|
|
||||||
|
class Test2Output(BaseInvocationOutput):
|
||||||
|
type: Literal["test2_output"] = "test2_output"
|
||||||
|
|
||||||
|
class TestInvocation2(BaseInvocation):
|
||||||
|
type: Literal["test2"] = "test2"
|
||||||
|
|
||||||
|
def invoke(self, context) -> Test2Output:
|
||||||
|
return Test2Output()
|
||||||
|
|
||||||
|
update_invocations_union()
|
||||||
|
|
||||||
loop.run_until_complete(server.serve())
|
loop.run_until_complete(server.serve())
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origi
|
|||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field, ModelField
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from ..invocations import * # noqa: F401 F403
|
from ..invocations import * # noqa: F401 F403
|
||||||
@ -232,7 +232,39 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
|||||||
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class DynamicBaseModel(BaseModel):
|
||||||
|
"""https://github.com/pydantic/pydantic/issues/1937#issuecomment-695313040"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_fields(cls, **field_definitions: Any):
|
||||||
|
new_fields: dict[str, ModelField] = {}
|
||||||
|
new_annotations: dict[str, Optional[type]] = {}
|
||||||
|
|
||||||
|
for f_name, f_def in field_definitions.items():
|
||||||
|
if isinstance(f_def, tuple):
|
||||||
|
try:
|
||||||
|
f_annotation, f_value = f_def
|
||||||
|
except ValueError as e:
|
||||||
|
raise Exception(
|
||||||
|
"field definitions should either be a tuple of (<type>, <default>) or just a "
|
||||||
|
"default value, unfortunately this means tuples as "
|
||||||
|
"default values are not allowed"
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
f_annotation, f_value = None, f_def
|
||||||
|
|
||||||
|
if f_annotation:
|
||||||
|
new_annotations[f_name] = f_annotation
|
||||||
|
|
||||||
|
new_fields[f_name] = ModelField.infer(
|
||||||
|
name=f_name, value=f_value, annotation=f_annotation, class_validators=None, config=cls.__config__
|
||||||
|
)
|
||||||
|
|
||||||
|
cls.__fields__.update(new_fields)
|
||||||
|
cls.__annotations__.update(new_annotations)
|
||||||
|
|
||||||
|
|
||||||
|
class Graph(DynamicBaseModel):
|
||||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||||
@ -700,7 +732,7 @@ class Graph(BaseModel):
|
|||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
class GraphExecutionState(BaseModel):
|
class GraphExecutionState(DynamicBaseModel):
|
||||||
"""Tracks the state of a graph execution"""
|
"""Tracks the state of a graph execution"""
|
||||||
|
|
||||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||||
@ -1131,3 +1163,24 @@ class LibraryGraph(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
GraphInvocation.update_forward_refs()
|
GraphInvocation.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
|
def update_invocations_union() -> None:
|
||||||
|
global InvocationsUnion
|
||||||
|
global InvocationOutputsUnion
|
||||||
|
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||||
|
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
|
||||||
|
|
||||||
|
Graph.add_fields(
|
||||||
|
nodes=(
|
||||||
|
dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]],
|
||||||
|
Field(description="The nodes in this graph", default_factory=dict),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
GraphExecutionState.add_fields(
|
||||||
|
results=(
|
||||||
|
dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]],
|
||||||
|
Field(description="The results of node executions", default_factory=dict),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
31
invokeai/app/util/dev_reload.py
Normal file
31
invokeai/app/util/dev_reload.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import jurigged
|
||||||
|
from jurigged.codetools import ClassDefinition
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
logger = InvokeAILogger.getLogger(name=__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def reload_nodes(path: str, codefile: jurigged.CodeFile):
|
||||||
|
"""Callback function for jurigged post-run events."""
|
||||||
|
# Things we have access to here:
|
||||||
|
# codefile.module:module - the module object associated with this file
|
||||||
|
# codefile.module_name:str - the full module name (its key in sys.modules)
|
||||||
|
# codefile.root:ModuleCode - an AST of the current source
|
||||||
|
|
||||||
|
# This is only reading top-level statements, not walking the whole AST, but class definition should be top-level, right?
|
||||||
|
class_names = [statement.name for statement in codefile.root.children if isinstance(statement, ClassDefinition)]
|
||||||
|
classes = [getattr(codefile.module, name) for name in class_names]
|
||||||
|
invocations = [cls for cls in classes if issubclass(cls, BaseInvocation)]
|
||||||
|
# outputs = [cls for cls in classes if issubclass(cls, BaseInvocationOutput)]
|
||||||
|
|
||||||
|
# We should assume jurigged has already replaced all references to methods of these classes,
|
||||||
|
# but it hasn't re-executed any annotations on them (like @title or @tags).
|
||||||
|
# We need to re-do anything that involved introspection like BaseInvocation.get_all_subclasses()
|
||||||
|
logger.info("File reloaded: %s contains invocation classes %s", path, invocations)
|
||||||
|
|
||||||
|
|
||||||
|
def start_reloader():
|
||||||
|
watcher = jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
||||||
|
watcher.postrun.register(reload_nodes, apply_history=False)
|
Reference in New Issue
Block a user