diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000..8232fc4b93 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[run] +omit='.env/*' +source='.' + +[report] +show_missing = true diff --git a/.gitignore b/.gitignore index 9adb0be85a..9b33e07164 100644 --- a/.gitignore +++ b/.gitignore @@ -68,6 +68,7 @@ htmlcov/ .cache nosetests.xml coverage.xml +cov.xml *.cover *.py,cover .hypothesis/ diff --git a/.pytest.ini b/.pytest.ini new file mode 100644 index 0000000000..16ccfafe80 --- /dev/null +++ b/.pytest.ini @@ -0,0 +1,5 @@ +[pytest] +DJANGO_SETTINGS_MODULE = webtas.settings +; python_files = tests.py test_*.py *_tests.py + +addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml diff --git a/docs/contributing/ARCHITECTURE.md b/docs/contributing/ARCHITECTURE.md new file mode 100644 index 0000000000..d74df94492 --- /dev/null +++ b/docs/contributing/ARCHITECTURE.md @@ -0,0 +1,93 @@ +# Invoke.AI Architecture + +```mermaid +flowchart TB + + subgraph apps[Applications] + webui[WebUI] + cli[CLI] + + subgraph webapi[Web API] + api[HTTP API] + sio[Socket.IO] + end + + end + + subgraph invoke[Invoke] + direction LR + invoker + services + sessions + invocations + end + + subgraph core[AI Core] + Generate + end + + webui --> webapi + webapi --> invoke + cli --> invoke + + invoker --> services & sessions + invocations --> services + sessions --> invocations + + services --> core + + %% Styles + classDef sg fill:#5028C8,font-weight:bold,stroke-width:2,color:#fff,stroke:#14141A + classDef default stroke-width:2px,stroke:#F6B314,color:#fff,fill:#14141A + + class apps,webapi,invoke,core sg + +``` + +## Applications + +Applications are built on top of the invoke framework. They should construct `invoker` and then interact through it. They should avoid interacting directly with core code in order to support a variety of configurations. + +### Web UI + +The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/frontend` and the backend code is found in `/ldm/invoke/app/api_app.py` and `/ldm/invoke/app/api/`. The code is further organized as such: + +| Component | Description | +| --- | --- | +| api_app.py | Sets up the API app, annotates the OpenAPI spec with additional data, and runs the API | +| dependencies | Creates all invoker services and the invoker, and provides them to the API | +| events | An eventing system that could in the future be adapted to support horizontal scale-out | +| sockets | The Socket.IO interface - handles listening to and emitting session events (events are defined in the events service module) | +| routers | API definitions for different areas of API functionality | + +### CLI + +The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/ldm/invoke/app/cli_app.py`. + +## Invoke + +The Invoke framework provides the interface to the underlying AI systems and is built with flexibility and extensibility in mind. There are four major concepts: invoker, sessions, invocations, and services. + +### Invoker + +The invoker (`/ldm/invoke/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services: +- **invocation services**, which are used by invocations to interact with core functionality. +- **invoker services**, which are used by the invoker to manage sessions and manage the invocation queue. + +### Sessions + +Invocations and links between them form a graph, which is maintained in a session. Sessions can be queued for invocation, which will execute their graph (either the next ready invocation, or all invocations). Sessions also maintain execution history for the graph (including storage of any outputs). An invocation may be added to a session at any time, and there is capability to add and entire graph at once, as well as to automatically link new invocations to previous invocations. Invocations can not be deleted or modified once added. + +The session graph does not support looping. This is left as an application problem to prevent additional complexity in the graph. + +### Invocations + +Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/ldm/invoke/app/invocations`, and are all automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations. + +### Services + +Services provide invocations access AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/ldm/invoke/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import). + +## AI Core + +The AI Core is represented by the rest of the code base (i.e. the code outside of `/ldm/invoke/app/`). diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md new file mode 100644 index 0000000000..c8a97c19e4 --- /dev/null +++ b/docs/contributing/INVOCATIONS.md @@ -0,0 +1,105 @@ +# Invocations + +Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images. + +## Creating a new invocation + +To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API. + +An invocation looks like this: + +```py +class UpscaleInvocation(BaseInvocation): + """Upscales an image.""" + type: Literal['upscale'] = 'upscale' + + # Inputs + image: Union[ImageField,None] = Field(description="The input image") + strength: float = Field(default=0.75, gt=0, le=1, description="The strength") + level: Literal[2,4] = Field(default=2, description = "The upscale level") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + results = context.services.generate.upscale_and_reconstruct( + image_list = [[image, 0]], + upscale = (self.level, self.strength), + strength = 0.0, # GFPGAN strength + save_original = False, + image_callback = None, + ) + + # Results are image and seed, unwrap for now + # TODO: can this return multiple results? + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, results[0][0]) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) +``` + +Each portion is important to implement correctly. + +### Class definition and type +```py +class UpscaleInvocation(BaseInvocation): + """Upscales an image.""" + type: Literal['upscale'] = 'upscale' +``` +All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint. + +### Inputs +```py + # Inputs + image: Union[ImageField,None] = Field(description="The input image") + strength: float = Field(default=0.75, gt=0, le=1, description="The strength") + level: Literal[2,4] = Field(default=2, description="The upscale level") +``` +Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example: +| Part | Value | Description | +| ---- | ----- | ----------- | +| Name | `strength` | This field is referred to as `strength` | +| Type Hint | `float` | This field must be of type `float` | +| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. | + +Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation. + +The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links). + +Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching. + +### Invoke Function +```py + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + results = context.services.generate.upscale_and_reconstruct( + image_list = [[image, 0]], + upscale = (self.level, self.strength), + strength = 0.0, # GFPGAN strength + save_original = False, + image_callback = None, + ) + + # Results are image and seed, unwrap for now + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, results[0][0]) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) +``` +The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`. + +Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order). + +Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden. + +### Outputs +```py +class ImageOutput(BaseInvocationOutput): + """Base class for invocations that output an image""" + type: Literal['image'] = 'image' + + image: ImageField = Field(default=None, description="The output image") +``` +Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking. diff --git a/ldm/generate.py b/ldm/generate.py index 413a1e25cb..256f214b25 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -1030,6 +1030,8 @@ class Generate: image_callback=None, prefix=None, ): + + results = [] for r in image_list: image, seed = r try: @@ -1083,6 +1085,10 @@ class Generate: else: r[0] = image + results.append([image, seed]) + + return results + def apply_textmask( self, image_path: str, prompt: str, callback, threshold: float = 0.5 ): diff --git a/ldm/invoke/app/api/dependencies.py b/ldm/invoke/app/api/dependencies.py new file mode 100644 index 0000000000..60dd522803 --- /dev/null +++ b/ldm/invoke/app/api/dependencies.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from argparse import Namespace +import os + +from ..services.processor import DefaultInvocationProcessor + +from ..services.graph import GraphExecutionState +from ..services.sqlite import SqliteItemStorage + +from ...globals import Globals + +from ..services.image_storage import DiskImageStorage +from ..services.invocation_queue import MemoryInvocationQueue +from ..services.invocation_services import InvocationServices +from ..services.invoker import Invoker, InvokerServices +from ..services.generate_initializer import get_generate +from .events import FastAPIEventService + + +# TODO: is there a better way to achieve this? +def check_internet()->bool: + ''' + Return true if the internet is reachable. + It does this by pinging huggingface.co. + ''' + import urllib.request + host = 'http://huggingface.co' + try: + urllib.request.urlopen(host,timeout=1) + return True + except: + return False + + +class ApiDependencies: + """Contains and initializes all dependencies for the API""" + invoker: Invoker = None + + @staticmethod + def initialize( + args, + config, + event_handler_id: int + ): + Globals.try_patchmatch = args.patchmatch + Globals.always_use_cpu = args.always_use_cpu + Globals.internet_available = args.internet_available and check_internet() + Globals.disable_xformers = not args.xformers + Globals.ckpt_convert = args.ckpt_convert + + # TODO: Use a logger + print(f'>> Internet connectivity is {Globals.internet_available}') + + generate = get_generate(args, config) + + events = FastAPIEventService(event_handler_id) + + output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs')) + + images = DiskImageStorage(output_folder) + + services = InvocationServices( + generate = generate, + events = events, + images = images + ) + + # TODO: build a file/path manager? + db_location = os.path.join(output_folder, 'invokeai.db') + + invoker_services = InvokerServices( + queue = MemoryInvocationQueue(), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor() + ) + + ApiDependencies.invoker = Invoker(services, invoker_services) + + @staticmethod + def shutdown(): + if ApiDependencies.invoker: + ApiDependencies.invoker.stop() diff --git a/ldm/invoke/app/api/events.py b/ldm/invoke/app/api/events.py new file mode 100644 index 0000000000..701b48a316 --- /dev/null +++ b/ldm/invoke/app/api/events.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import asyncio +from queue import Empty, Queue +from typing import Any +from fastapi_events.dispatcher import dispatch +from ..services.events import EventServiceBase +import threading + +class FastAPIEventService(EventServiceBase): + event_handler_id: int + __queue: Queue + __stop_event: threading.Event + + def __init__(self, event_handler_id: int) -> None: + self.event_handler_id = event_handler_id + self.__queue = Queue() + self.__stop_event = threading.Event() + asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event)) + + super().__init__() + + + def stop(self, *args, **kwargs): + self.__stop_event.set() + self.__queue.put(None) + + + def dispatch(self, event_name: str, payload: Any) -> None: + self.__queue.put(dict( + event_name = event_name, + payload = payload + )) + + + async def __dispatch_from_queue(self, stop_event: threading.Event): + """Get events on from the queue and dispatch them, from the correct thread""" + while not stop_event.is_set(): + try: + event = self.__queue.get(block = False) + if not event: # Probably stopping + continue + + dispatch( + event.get('event_name'), + payload = event.get('payload'), + middleware_id = self.event_handler_id) + + except Empty: + await asyncio.sleep(0.001) + pass + + except asyncio.CancelledError as e: + raise e # Raise a proper error diff --git a/ldm/invoke/app/api/routers/images.py b/ldm/invoke/app/api/routers/images.py new file mode 100644 index 0000000000..1ae116e49d --- /dev/null +++ b/ldm/invoke/app/api/routers/images.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, timezone +from fastapi import Path, UploadFile, Request +from fastapi.routing import APIRouter +from fastapi.responses import FileResponse, Response +from PIL import Image +from ...services.image_storage import ImageType +from ..dependencies import ApiDependencies + +images_router = APIRouter( + prefix = '/v1/images', + tags = ['images'] +) + + +@images_router.get('/{image_type}/{image_name}', + operation_id = 'get_image' + ) +async def get_image( + image_type: ImageType = Path(description = "The type of image to get"), + image_name: str = Path(description = "The name of the image to get") +): + """Gets a result""" + # TODO: This is not really secure at all. At least make sure only output results are served + filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name) + return FileResponse(filename) + +@images_router.post('/uploads/', + operation_id = 'upload_image', + responses = { + 201: {'description': 'The image was uploaded successfully'}, + 404: {'description': 'Session not found'} + }) +async def upload_image( + file: UploadFile, + request: Request +): + if not file.content_type.startswith('image'): + return Response(status_code = 415) + + contents = await file.read() + try: + im = Image.open(contents) + except: + # Error opening the image + return Response(status_code = 415) + + filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png' + ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im) + + return Response( + status_code=201, + headers = { + 'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename) + } + ) diff --git a/ldm/invoke/app/api/routers/sessions.py b/ldm/invoke/app/api/routers/sessions.py new file mode 100644 index 0000000000..77008ad6e4 --- /dev/null +++ b/ldm/invoke/app/api/routers/sessions.py @@ -0,0 +1,232 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from typing import List, Optional, Union, Annotated +from fastapi import Query, Path, Body +from fastapi.routing import APIRouter +from fastapi.responses import Response +from pydantic.fields import Field + +from ...services.item_storage import PaginatedResults +from ..dependencies import ApiDependencies +from ...invocations.baseinvocation import BaseInvocation +from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError +from ...invocations import * + +session_router = APIRouter( + prefix = '/v1/sessions', + tags = ['sessions'] +) + + +@session_router.post('/', + operation_id = 'create_session', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid json'} + }) +async def create_session( + graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with") +) -> GraphExecutionState: + """Creates a new session, optionally initializing it with an invocation graph""" + session = ApiDependencies.invoker.create_execution_state(graph) + return session + + +@session_router.get('/', + operation_id = 'list_sessions', + responses = { + 200: {"model": PaginatedResults[GraphExecutionState]} + }) +async def list_sessions( + page: int = Query(default = 0, description = "The page of results to get"), + per_page: int = Query(default = 10, description = "The number of results per page"), + query: str = Query(default = '', description = "The query string to search for") +) -> PaginatedResults[GraphExecutionState]: + """Gets a list of sessions, optionally searching""" + if filter == '': + result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page) + else: + result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page) + return result + + +@session_router.get('/{session_id}', + operation_id = 'get_session', + responses = { + 200: {"model": GraphExecutionState}, + 404: {'description': 'Session not found'} + }) +async def get_session( + session_id: str = Path(description = "The id of the session to get") +) -> GraphExecutionState: + """Gets a session""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + else: + return session + + +@session_router.post('/{session_id}/nodes', + operation_id = 'add_node', + responses = { + 200: {"model": str}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def add_node( + session_id: str = Path(description = "The id of the session"), + node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add") +) -> str: + """Adds a node to the graph""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.add_node(node) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session.id + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.put('/{session_id}/nodes/{node_path}', + operation_id = 'update_node', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def update_node( + session_id: str = Path(description = "The id of the session"), + node_path: str = Path(description = "The path to the node in the graph"), + node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node") +) -> GraphExecutionState: + """Updates a node in the graph and removes all linked edges""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.update_node(node_path, node) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.delete('/{session_id}/nodes/{node_path}', + operation_id = 'delete_node', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def delete_node( + session_id: str = Path(description = "The id of the session"), + node_path: str = Path(description = "The path to the node to delete") +) -> GraphExecutionState: + """Deletes a node in the graph and removes all linked edges""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.delete_node(node_path) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.post('/{session_id}/edges', + operation_id = 'add_edge', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def add_edge( + session_id: str = Path(description = "The id of the session"), + edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add") +) -> GraphExecutionState: + """Adds an edge to the graph""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + session.add_edge(edge) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +# TODO: the edge being in the path here is really ugly, find a better solution +@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}', + operation_id = 'delete_edge', + responses = { + 200: {"model": GraphExecutionState}, + 400: {'description': 'Invalid node or link'}, + 404: {'description': 'Session not found'} + } +) +async def delete_edge( + session_id: str = Path(description = "The id of the session"), + from_node_id: str = Path(description = "The id of the node the edge is coming from"), + from_field: str = Path(description = "The field of the node the edge is coming from"), + to_node_id: str = Path(description = "The id of the node the edge is going to"), + to_field: str = Path(description = "The field of the node the edge is going to") +) -> GraphExecutionState: + """Deletes an edge from the graph""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + try: + edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field)) + session.delete_edge(edge) + ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API? + return session + except NodeAlreadyExecutedError: + return Response(status_code = 400) + except IndexError: + return Response(status_code = 400) + + +@session_router.put('/{session_id}/invoke', + operation_id = 'invoke_session', + responses = { + 200: {"model": None}, + 202: {'description': 'The invocation is queued'}, + 400: {'description': 'The session has no invocations ready to invoke'}, + 404: {'description': 'Session not found'} + }) +async def invoke_session( + session_id: str = Path(description = "The id of the session to invoke"), + all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations") +) -> None: + """Invokes a session""" + session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id) + if session is None: + return Response(status_code = 404) + + if session.is_complete(): + return Response(status_code = 400) + + ApiDependencies.invoker.invoke(session, invoke_all = all) + return Response(status_code=202) diff --git a/ldm/invoke/app/api/sockets.py b/ldm/invoke/app/api/sockets.py new file mode 100644 index 0000000000..eb4d5403c0 --- /dev/null +++ b/ldm/invoke/app/api/sockets.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from fastapi import FastAPI +from fastapi_socketio import SocketManager +from fastapi_events.handlers.local import local_handler +from fastapi_events.typing import Event +from ..services.events import EventServiceBase + +class SocketIO: + __sio: SocketManager + + def __init__(self, app: FastAPI): + self.__sio = SocketManager(app = app) + self.__sio.on('subscribe', handler=self._handle_sub) + self.__sio.on('unsubscribe', handler=self._handle_unsub) + + local_handler.register( + event_name = EventServiceBase.session_event, + _func=self._handle_session_event + ) + + async def _handle_session_event(self, event: Event): + await self.__sio.emit( + event = event[1]['event'], + data = event[1]['data'], + room = event[1]['data']['graph_execution_state_id'] + ) + + async def _handle_sub(self, sid, data, *args, **kwargs): + if 'session' in data: + self.__sio.enter_room(sid, data['session']) + + # @app.sio.on('unsubscribe') + async def _handle_unsub(self, sid, data, *args, **kwargs): + if 'session' in data: + self.__sio.leave_room(sid, data['session']) diff --git a/ldm/invoke/app/api_app.py b/ldm/invoke/app/api_app.py new file mode 100644 index 0000000000..db79b0d7e8 --- /dev/null +++ b/ldm/invoke/app/api_app.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import asyncio +from inspect import signature +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi +from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html +from fastapi.staticfiles import StaticFiles +from fastapi_events.middleware import EventHandlerASGIMiddleware +from fastapi_events.handlers.local import local_handler +from fastapi.middleware.cors import CORSMiddleware +from pydantic.schema import schema +import uvicorn +from .api.sockets import SocketIO +from .invocations import * +from .invocations.baseinvocation import BaseInvocation +from .api.routers import images, sessions +from .api.dependencies import ApiDependencies +from ..args import Args + +# Create the app +# TODO: create this all in a method so configuration/etc. can be passed in? +app = FastAPI( + title = "Invoke AI", + docs_url = None, + redoc_url = None +) + +# Add event handler +event_handler_id: int = id(app) +app.add_middleware( + EventHandlerASGIMiddleware, + handlers = [local_handler], # TODO: consider doing this in services to support different configurations + middleware_id = event_handler_id) + +# Add CORS +# TODO: use configuration for this +origins = [] +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +socket_io = SocketIO(app) + +config = {} + +# Add startup event to load dependencies +@app.on_event('startup') +async def startup_event(): + args = Args() + config = args.parse_args() + + ApiDependencies.initialize( + args = args, + config = config, + event_handler_id = event_handler_id + ) + +# Shut down threads +@app.on_event('shutdown') +async def shutdown_event(): + ApiDependencies.shutdown() + +# Include all routers +# TODO: REMOVE +# app.include_router( +# invocation.invocation_router, +# prefix = '/api') + +app.include_router( + sessions.session_router, + prefix = '/api' +) + +app.include_router( + images.images_router, + prefix = '/api' +) + +# Build a custom OpenAPI to include all outputs +# TODO: can outputs be included on metadata of invocation schemas somehow? +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + openapi_schema = get_openapi( + title = app.title, + description = "An API for invoking AI image operations", + version = "1.0.0", + routes = app.routes + ) + + # Add all outputs + all_invocations = BaseInvocation.get_invocations() + output_types = set() + output_type_titles = dict() + for invoker in all_invocations: + output_type = signature(invoker.invoke).return_annotation + output_types.add(output_type) + + output_schemas = schema(output_types, ref_prefix="#/components/schemas/") + for schema_key, output_schema in output_schemas['definitions'].items(): + openapi_schema["components"]["schemas"][schema_key] = output_schema + + # TODO: note that we assume the schema_key here is the TYPE.__name__ + # This could break in some cases, figure out a better way to do it + output_type_titles[schema_key] = output_schema['title'] + + # Add a reference to the output type to additionalProperties of the invoker schema + for invoker in all_invocations: + invoker_name = invoker.__name__ + output_type = signature(invoker.invoke).return_annotation + output_type_title = output_type_titles[output_type.__name__] + invoker_schema = openapi_schema["components"]["schemas"][invoker_name] + outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' } + if 'additionalProperties' not in invoker_schema: + invoker_schema['additionalProperties'] = {} + + invoker_schema['additionalProperties']['outputs'] = outputs_ref + + app.openapi_schema = openapi_schema + return app.openapi_schema + +app.openapi = custom_openapi + +# Override API doc favicons +app.mount('/static', StaticFiles(directory='static/dream_web'), name='static') + +@app.get("/docs", include_in_schema=False) +def overridden_swagger(): + return get_swagger_ui_html( + openapi_url=app.openapi_url, + title=app.title, + swagger_favicon_url="/static/favicon.ico" + ) + +@app.get("/redoc", include_in_schema=False) +def overridden_redoc(): + return get_redoc_html( + openapi_url=app.openapi_url, + title=app.title, + redoc_favicon_url="/static/favicon.ico" + ) + +def invoke_api(): + # Start our own event loop for eventing usage + # TODO: determine if there's a better way to do this + loop = asyncio.new_event_loop() + config = uvicorn.Config( + app = app, + host = "0.0.0.0", + port = 9090, + loop = loop) + # Use access_log to turn off logging + + server = uvicorn.Server(config) + loop.run_until_complete(server.serve()) + + +if __name__ == "__main__": + invoke_api() diff --git a/ldm/invoke/app/cli_app.py b/ldm/invoke/app/cli_app.py new file mode 100644 index 0000000000..6071afabb2 --- /dev/null +++ b/ldm/invoke/app/cli_app.py @@ -0,0 +1,306 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import argparse +import shlex +import os +import time +from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints +from pydantic import BaseModel +from pydantic.fields import Field + +from .services.processor import DefaultInvocationProcessor + +from .services.graph import EdgeConnection, GraphExecutionState + +from .services.sqlite import SqliteItemStorage + +from .invocations.image import ImageField +from .services.generate_initializer import get_generate +from .services.image_storage import DiskImageStorage +from .services.invocation_queue import MemoryInvocationQueue +from .invocations.baseinvocation import BaseInvocation +from .services.invocation_services import InvocationServices +from .services.invoker import Invoker, InvokerServices +from .invocations import * +from ..args import Args +from .services.events import EventServiceBase + + +class InvocationCommand(BaseModel): + invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type") + + +class InvalidArgs(Exception): + pass + + +def get_invocation_parser() -> argparse.ArgumentParser: + + # Create invocation parser + parser = argparse.ArgumentParser() + def exit(*args, **kwargs): + raise InvalidArgs + parser.exit = exit + + subparsers = parser.add_subparsers(dest='type') + invocation_parsers = dict() + + # Add history parser + history_parser = subparsers.add_parser('history', help="Shows the invocation history") + history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show") + + # Add default parser + default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name") + default_parser.add_argument('input', type=str, help="The input field") + default_parser.add_argument('value', help="The default value") + + default_parser = subparsers.add_parser('reset_default', help="Resets a default value") + default_parser.add_argument('input', type=str, help="The input field") + + # Create subparsers for each invocation + invocations = BaseInvocation.get_all_subclasses() + for invocation in invocations: + hints = get_type_hints(invocation) + cmd_name = get_args(hints['type'])[0] + command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__) + invocation_parsers[cmd_name] = command_parser + + # Add linking capability + command_parser.add_argument('--link', '-l', action='append', nargs=3, + help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)") + + command_parser.add_argument('--link_node', '-ln', action='append', + help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)") + + # Convert all fields to arguments + fields = invocation.__fields__ + for name, field in fields.items(): + if name in ['id', 'type']: + continue + + if get_origin(field.type_) == Literal: + allowed_values = get_args(field.type_) + allowed_types = set() + for val in allowed_values: + allowed_types.add(type(val)) + allowed_types_list = list(allowed_types) + field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] + + command_parser.add_argument( + f"--{name}", + dest=name, + type=field_type, + default=field.default, + choices = allowed_values, + help=field.field_info.description + ) + else: + command_parser.add_argument( + f"--{name}", + dest=name, + type=field.type_, + default=field.default, + help=field.field_info.description + ) + + return parser + + +def get_invocation_command(invocation) -> str: + fields = invocation.__fields__.items() + type_hints = get_type_hints(type(invocation)) + command = [invocation.type] + for name,field in fields: + if name in ['id', 'type']: + continue + + # TODO: add links + + # Skip image fields when serializing command + type_hint = type_hints.get(name) or None + if type_hint is ImageField or ImageField in get_args(type_hint): + continue + + field_value = getattr(invocation, name) + field_default = field.default + if field_value != field_default: + if type_hint is str or str in get_args(type_hint): + command.append(f'--{name} "{field_value}"') + else: + command.append(f'--{name} {field_value}') + + return ' '.join(command) + + +def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]: + """Gets the history of fully-executed invocations for a graph execution""" + return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes) + + +def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]: + """Generates all possible edges between two invocations""" + atype = type(a) + btype = type(b) + + aoutputtype = atype.get_output_type() + + afields = get_type_hints(aoutputtype) + bfields = get_type_hints(btype) + + matching_fields = set(afields.keys()).intersection(bfields.keys()) + + # Remove invalid fields + invalid_fields = set(['type', 'id']) + matching_fields = matching_fields.difference(invalid_fields) + + edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields] + return edges + + +def invoke_cli(): + args = Args() + config = args.parse_args() + + generate = get_generate(args, config) + + # NOTE: load model on first use, uncomment to load at startup + # TODO: Make this a config option? + #generate.load_model() + + events = EventServiceBase() + + output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs')) + + services = InvocationServices( + generate = generate, + events = events, + images = DiskImageStorage(output_folder) + ) + + # TODO: build a file/path manager? + db_location = os.path.join(output_folder, 'invokeai.db') + + invoker_services = InvokerServices( + queue = MemoryInvocationQueue(), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor() + ) + + invoker = Invoker(services, invoker_services) + session = invoker.create_execution_state() + + parser = get_invocation_parser() + + # Uncomment to print out previous sessions at startup + # print(invoker_services.session_manager.list()) + + # Defaults storage + defaults: Dict[str, Any] = dict() + + while True: + try: + cmd_input = input("> ") + except KeyboardInterrupt: + # Ctrl-c exits + break + + if cmd_input in ['exit','q']: + break; + + if cmd_input in ['--help','help','h','?']: + parser.print_help() + continue + + try: + # Refresh the state of the session + session = invoker.invoker_services.graph_execution_manager.get(session.id) + history = list(get_graph_execution_history(session)) + + # Split the command for piping + cmds = cmd_input.split('|') + start_id = len(history) + current_id = start_id + new_invocations = list() + for cmd in cmds: + # Parse args to create invocation + args = vars(parser.parse_args(shlex.split(cmd.strip()))) + + # Check for special commands + # TODO: These might be better as Pydantic models, similar to the invocations + if args['type'] == 'history': + history_count = args['count'] or 5 + for i in range(min(history_count, len(history))): + entry_id = history[-1 - i] + entry = session.graph.get_node(entry_id) + print(f'{entry_id}: {get_invocation_command(entry.invocation)}') + continue + + if args['type'] == 'reset_default': + if args['input'] in defaults: + del defaults[args['input']] + continue + + if args['type'] == 'default': + field = args['input'] + field_value = args['value'] + defaults[field] = field_value + continue + + # Override defaults + for field_name,field_default in defaults.items(): + if field_name in args: + args[field_name] = field_default + + # Parse invocation + args['id'] = current_id + command = InvocationCommand(invocation = args) + + # Pipe previous command output (if there was a previous command) + edges = [] + if len(history) > 0 or current_id != start_id: + from_id = history[0] if current_id == start_id else str(current_id - 1) + from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id) + matching_edges = generate_matching_edges(from_node, command.invocation) + edges.extend(matching_edges) + + # Parse provided links + if 'link_node' in args and args['link_node']: + for link in args['link_node']: + link_node = session.graph.get_node(link) + matching_edges = generate_matching_edges(link_node, command.invocation) + edges.extend(matching_edges) + + if 'link' in args and args['link']: + for link in args['link']: + edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2]))) + + new_invocations.append((command.invocation, edges)) + + current_id = current_id + 1 + + # Command line was parsed successfully + # Add the invocations to the session + for invocation in new_invocations: + session.add_node(invocation[0]) + for edge in invocation[1]: + session.add_edge(edge) + + # Execute all available invocations + invoker.invoke(session, invoke_all = True) + while not session.is_complete(): + # Wait some time + session = invoker.invoker_services.graph_execution_manager.get(session.id) + time.sleep(0.1) + + except InvalidArgs: + print('Invalid command, use "help" to list commands') + continue + + except SystemExit: + continue + + invoker.stop() + + +if __name__ == "__main__": + invoke_cli() diff --git a/ldm/invoke/app/invocations/__init__.py b/ldm/invoke/app/invocations/__init__.py new file mode 100644 index 0000000000..6407a1cdee --- /dev/null +++ b/ldm/invoke/app/invocations/__init__.py @@ -0,0 +1,8 @@ +import os + +__all__ = [] + +dirname = os.path.dirname(os.path.abspath(__file__)) +for f in os.listdir(dirname): + if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py": + __all__.append(f[:-3]) diff --git a/ldm/invoke/app/invocations/baseinvocation.py b/ldm/invoke/app/invocations/baseinvocation.py new file mode 100644 index 0000000000..1ad2d99112 --- /dev/null +++ b/ldm/invoke/app/invocations/baseinvocation.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC, abstractmethod +from inspect import signature +from typing import get_args, get_type_hints +from pydantic import BaseModel, Field +from ..services.invocation_services import InvocationServices + + +class InvocationContext: + services: InvocationServices + graph_execution_state_id: str + + def __init__(self, services: InvocationServices, graph_execution_state_id: str): + self.services = services + self.graph_execution_state_id = graph_execution_state_id + + +class BaseInvocationOutput(BaseModel): + """Base class for all invocation outputs""" + + # All outputs must include a type name like this: + # type: Literal['your_output_name'] + + @classmethod + def get_all_subclasses_tuple(cls): + subclasses = [] + toprocess = [cls] + while len(toprocess) > 0: + next = toprocess.pop(0) + next_subclasses = next.__subclasses__() + subclasses.extend(next_subclasses) + toprocess.extend(next_subclasses) + return tuple(subclasses) + + +class BaseInvocation(ABC, BaseModel): + """A node to process inputs and produce outputs. + May use dependency injection in __init__ to receive providers. + """ + + # All invocations must include a type name like this: + # type: Literal['your_output_name'] + + @classmethod + def get_all_subclasses(cls): + subclasses = [] + toprocess = [cls] + while len(toprocess) > 0: + next = toprocess.pop(0) + next_subclasses = next.__subclasses__() + subclasses.extend(next_subclasses) + toprocess.extend(next_subclasses) + return subclasses + + @classmethod + def get_invocations(cls): + return tuple(BaseInvocation.get_all_subclasses()) + + @classmethod + def get_invocations_map(cls): + # Get the type strings out of the literals and into a dictionary + return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) + + @classmethod + def get_output_type(cls): + return signature(cls.invoke).return_annotation + + @abstractmethod + def invoke(self, context: InvocationContext) -> BaseInvocationOutput: + """Invoke with provided context and return outputs.""" + pass + + id: str = Field(description="The id of this node. Must be unique among all nodes.") diff --git a/ldm/invoke/app/invocations/cv.py b/ldm/invoke/app/invocations/cv.py new file mode 100644 index 0000000000..f950669736 --- /dev/null +++ b/ldm/invoke/app/invocations/cv.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal +import numpy +from pydantic import Field +from PIL import Image, ImageOps +import cv2 as cv +from .image import ImageField, ImageOutput +from .baseinvocation import BaseInvocation, InvocationContext +from ..services.image_storage import ImageType + + +class CvInpaintInvocation(BaseInvocation): + """Simple inpaint using opencv.""" + type: Literal['cv_inpaint'] = 'cv_inpaint' + + # Inputs + image: ImageField = Field(default=None, description="The image to inpaint") + mask: ImageField = Field(default=None, description="The mask to use when inpainting") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + mask = context.services.images.get(self.mask.image_type, self.mask.image_name) + + # Convert to cv image/mask + # TODO: consider making these utility functions + cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR) + cv_mask = numpy.array(ImageOps.invert(mask)) + + # Inpaint + cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA) + + # Convert back to Pillow + # TODO: consider making a utility function + image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, image_inpainted) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/invocations/generate.py b/ldm/invoke/app/invocations/generate.py new file mode 100644 index 0000000000..60b656bf0c --- /dev/null +++ b/ldm/invoke/app/invocations/generate.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, timezone +from typing import Any, Literal, Optional, Union +import numpy as np +from pydantic import Field +from PIL import Image +from skimage.exposure.histogram_matching import match_histograms +from .image import ImageField, ImageOutput +from .baseinvocation import BaseInvocation, InvocationContext +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices + + +SAMPLER_NAME_VALUES = Literal["ddim","plms","k_lms","k_dpm_2","k_dpm_2_a","k_euler","k_euler_a","k_heun"] + +# Text to image +class TextToImageInvocation(BaseInvocation): + """Generates an image using text2img.""" + type: Literal['txt2img'] = 'txt2img' + + # Inputs + # TODO: consider making prompt optional to enable providing prompt through a link + prompt: Optional[str] = Field(description="The prompt to generate an image from") + seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)") + steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") + width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image") + height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image") + cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt") + sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use") + seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams") + model: str = Field(default='', description="The model to use (currently ignored)") + progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation") + + # TODO: pass this an emitter method or something? or a session for dispatching? + def dispatch_progress(self, context: InvocationContext, sample: Any = None, step: int = 0) -> None: + context.services.events.emit_generator_progress( + context.graph_execution_state_id, self.id, step, float(step) / float(self.steps) + ) + + def invoke(self, context: InvocationContext) -> ImageOutput: + + def step_callback(sample, step = 0): + self.dispatch_progress(context, sample, step) + + # Handle invalid model parameter + # TODO: figure out if this can be done via a validator that uses the model_cache + # TODO: How to get the default model name now? + if self.model is None or self.model == '': + self.model = context.services.generate.model_name + + # Set the model (if already cached, this does nothing) + context.services.generate.set_model(self.model) + + results = context.services.generate.prompt2image( + prompt = self.prompt, + step_callback = step_callback, + **self.dict(exclude = {'prompt'}) # Shorthand for passing all of the parameters above manually + ) + + # Results are image and seed, unwrap for now and ignore the seed + # TODO: pre-seed? + # TODO: can this return multiple results? Should it? + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, results[0][0]) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) + + +class ImageToImageInvocation(TextToImageInvocation): + """Generates an image using img2img.""" + type: Literal['img2img'] = 'img2img' + + # Inputs + image: Union[ImageField,None] = Field(description="The input image") + strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image") + fit: bool = Field(default=True, description="Whether or not the result should be fit to the aspect ratio of the input image") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name) + mask = None + + def step_callback(sample, step = 0): + self.dispatch_progress(context, sample, step) + + # Handle invalid model parameter + # TODO: figure out if this can be done via a validator that uses the model_cache + # TODO: How to get the default model name now? + if self.model is None or self.model == '': + self.model = context.services.generate.model_name + + # Set the model (if already cached, this does nothing) + context.services.generate.set_model(self.model) + + results = context.services.generate.prompt2image( + prompt = self.prompt, + init_img = image, + init_mask = mask, + step_callback = step_callback, + **self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually + ) + + result_image = results[0][0] + + # Results are image and seed, unwrap for now and ignore the seed + # TODO: pre-seed? + # TODO: can this return multiple results? Should it? + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, result_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) + + +class InpaintInvocation(ImageToImageInvocation): + """Generates an image using inpaint.""" + type: Literal['inpaint'] = 'inpaint' + + # Inputs + mask: Union[ImageField,None] = Field(description="The mask") + inpaint_replace: float = Field(default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name) + mask = None if self.mask is None else context.services.images.get(self.mask.image_type, self.mask.image_name) + + def step_callback(sample, step = 0): + self.dispatch_progress(context, sample, step) + + # Handle invalid model parameter + # TODO: figure out if this can be done via a validator that uses the model_cache + # TODO: How to get the default model name now? + if self.model is None or self.model == '': + self.model = context.services.generate.model_name + + # Set the model (if already cached, this does nothing) + context.services.generate.set_model(self.model) + + results = context.services.generate.prompt2image( + prompt = self.prompt, + init_img = image, + init_mask = mask, + step_callback = step_callback, + **self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually + ) + + result_image = results[0][0] + + # Results are image and seed, unwrap for now and ignore the seed + # TODO: pre-seed? + # TODO: can this return multiple results? Should it? + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, result_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/invocations/image.py b/ldm/invoke/app/invocations/image.py new file mode 100644 index 0000000000..cb326b1bb7 --- /dev/null +++ b/ldm/invoke/app/invocations/image.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, timezone +from typing import Literal, Optional +import numpy +from pydantic import Field, BaseModel +from PIL import Image, ImageOps, ImageFilter +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices + + +class ImageField(BaseModel): + """An image field used for passing image objects between invocations""" + image_type: str = Field(default=ImageType.RESULT, description="The type of the image") + image_name: Optional[str] = Field(default=None, description="The name of the image") + + +class ImageOutput(BaseInvocationOutput): + """Base class for invocations that output an image""" + type: Literal['image'] = 'image' + + image: ImageField = Field(default=None, description="The output image") + + +class MaskOutput(BaseInvocationOutput): + """Base class for invocations that output a mask""" + type: Literal['mask'] = 'mask' + + mask: ImageField = Field(default=None, description="The output mask") + + +# TODO: this isn't really necessary anymore +class LoadImageInvocation(BaseInvocation): + """Load an image from a filename and provide it as output.""" + type: Literal['load_image'] = 'load_image' + + # Inputs + image_type: ImageType = Field(description="The type of the image") + image_name: str = Field(description="The name of the image") + + def invoke(self, context: InvocationContext) -> ImageOutput: + return ImageOutput( + image = ImageField(image_type = self.image_type, image_name = self.image_name) + ) + + +class ShowImageInvocation(BaseInvocation): + """Displays a provided image, and passes it forward in the pipeline.""" + type: Literal['show_image'] = 'show_image' + + # Inputs + image: ImageField = Field(default=None, description="The image to show") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + if image: + image.show() + + # TODO: how to handle failure? + + return ImageOutput( + image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name) + ) + + +class CropImageInvocation(BaseInvocation): + """Crops an image to a specified box. The box can be outside of the image.""" + type: Literal['crop'] = 'crop' + + # Inputs + image: ImageField = Field(default=None, description="The image to crop") + x: int = Field(default=0, description="The left x coordinate of the crop rectangle") + y: int = Field(default=0, description="The top y coordinate of the crop rectangle") + width: int = Field(default=512, gt=0, description="The width of the crop rectangle") + height: int = Field(default=512, gt=0, description="The height of the crop rectangle") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + + image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0)) + image_crop.paste(image, (-self.x, -self.y)) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, image_crop) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) + + +class PasteImageInvocation(BaseInvocation): + """Pastes an image into another image.""" + type: Literal['paste'] = 'paste' + + # Inputs + base_image: ImageField = Field(default=None, description="The base image") + image: ImageField = Field(default=None, description="The image to paste") + mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") + x: int = Field(default=0, description="The left x coordinate at which to paste the image") + y: int = Field(default=0, description="The top y coordinate at which to paste the image") + + def invoke(self, context: InvocationContext) -> ImageOutput: + base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name) + image = context.services.images.get(self.image.image_type, self.image.image_name) + mask = None if self.mask is None else ImageOps.invert(services.images.get(self.mask.image_type, self.mask.image_name)) + # TODO: probably shouldn't invert mask here... should user be required to do it? + + min_x = min(0, self.x) + min_y = min(0, self.y) + max_x = max(base_image.width, image.width + self.x) + max_y = max(base_image.height, image.height + self.y) + + new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0)) + new_image.paste(base_image, (abs(min_x), abs(min_y))) + new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask) + + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, new_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) + + +class MaskFromAlphaInvocation(BaseInvocation): + """Extracts the alpha channel of an image as a mask.""" + type: Literal['tomask'] = 'tomask' + + # Inputs + image: ImageField = Field(default=None, description="The image to create the mask from") + invert: bool = Field(default=False, description="Whether or not to invert the mask") + + def invoke(self, context: InvocationContext) -> MaskOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + + image_mask = image.split()[-1] + if self.invert: + image_mask = ImageOps.invert(image_mask) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, image_mask) + return MaskOutput( + mask = ImageField(image_type = image_type, image_name = image_name) + ) + + +class BlurInvocation(BaseInvocation): + """Blurs an image""" + type: Literal['blur'] = 'blur' + + # Inputs + image: ImageField = Field(default=None, description="The image to blur") + radius: float = Field(default=8.0, ge=0, description="The blur radius") + blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + + blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius) + blur_image = image.filter(blur) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, blur_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) + + +class LerpInvocation(BaseInvocation): + """Linear interpolation of all pixels of an image""" + type: Literal['lerp'] = 'lerp' + + # Inputs + image: ImageField = Field(default=None, description="The image to lerp") + min: int = Field(default=0, ge=0, le=255, description="The minimum output value") + max: int = Field(default=255, ge=0, le=255, description="The maximum output value") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + + image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 + image_arr = image_arr * (self.max - self.min) + self.max + + lerp_image = Image.fromarray(numpy.uint8(image_arr)) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, lerp_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) + + +class InverseLerpInvocation(BaseInvocation): + """Inverse linear interpolation of all pixels of an image""" + type: Literal['ilerp'] = 'ilerp' + + # Inputs + image: ImageField = Field(default=None, description="The image to lerp") + min: int = Field(default=0, ge=0, le=255, description="The minimum input value") + max: int = Field(default=255, ge=0, le=255, description="The maximum input value") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + + image_arr = numpy.asarray(image, dtype=numpy.float32) + image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 + + ilerp_image = Image.fromarray(numpy.uint8(image_arr)) + + image_type = ImageType.INTERMEDIATE + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, ilerp_image) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/invocations/prompt.py b/ldm/invoke/app/invocations/prompt.py new file mode 100644 index 0000000000..029cad9660 --- /dev/null +++ b/ldm/invoke/app/invocations/prompt.py @@ -0,0 +1,9 @@ +from typing import Literal +from pydantic.fields import Field +from .baseinvocation import BaseInvocationOutput + +class PromptOutput(BaseInvocationOutput): + """Base class for invocations that output a prompt""" + type: Literal['prompt'] = 'prompt' + + prompt: str = Field(default=None, description="The output prompt") diff --git a/ldm/invoke/app/invocations/reconstruct.py b/ldm/invoke/app/invocations/reconstruct.py new file mode 100644 index 0000000000..98201ce837 --- /dev/null +++ b/ldm/invoke/app/invocations/reconstruct.py @@ -0,0 +1,36 @@ +from datetime import datetime, timezone +from typing import Literal, Union +from pydantic import Field +from .image import ImageField, ImageOutput +from .baseinvocation import BaseInvocation, InvocationContext +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices + + +class RestoreFaceInvocation(BaseInvocation): + """Restores faces in an image.""" + type: Literal['restore_face'] = 'restore_face' + + # Inputs + image: Union[ImageField,None] = Field(description="The input image") + strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration") + + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + results = context.services.generate.upscale_and_reconstruct( + image_list = [[image, 0]], + upscale = None, + strength = self.strength, # GFPGAN strength + save_original = False, + image_callback = None, + ) + + # Results are image and seed, unwrap for now + # TODO: can this return multiple results? + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, results[0][0]) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/invocations/upscale.py b/ldm/invoke/app/invocations/upscale.py new file mode 100644 index 0000000000..1df8c44ea8 --- /dev/null +++ b/ldm/invoke/app/invocations/upscale.py @@ -0,0 +1,38 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, timezone +from typing import Literal, Union +from pydantic import Field +from .image import ImageField, ImageOutput +from .baseinvocation import BaseInvocation, InvocationContext +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices + + +class UpscaleInvocation(BaseInvocation): + """Upscales an image.""" + type: Literal['upscale'] = 'upscale' + + # Inputs + image: Union[ImageField,None] = Field(description="The input image", default=None) + strength: float = Field(default=0.75, gt=0, le=1, description="The strength") + level: Literal[2,4] = Field(default=2, description = "The upscale level") + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get(self.image.image_type, self.image.image_name) + results = context.services.generate.upscale_and_reconstruct( + image_list = [[image, 0]], + upscale = (self.level, self.strength), + strength = 0.0, # GFPGAN strength + save_original = False, + image_callback = None, + ) + + # Results are image and seed, unwrap for now + # TODO: can this return multiple results? + image_type = ImageType.RESULT + image_name = context.services.images.create_name(context.graph_execution_state_id, self.id) + context.services.images.save(image_type, image_name, results[0][0]) + return ImageOutput( + image = ImageField(image_type = image_type, image_name = image_name) + ) diff --git a/ldm/invoke/app/services/__init__.py b/ldm/invoke/app/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ldm/invoke/app/services/events.py b/ldm/invoke/app/services/events.py new file mode 100644 index 0000000000..7b850b61ac --- /dev/null +++ b/ldm/invoke/app/services/events.py @@ -0,0 +1,78 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Any, Dict + + +class EventServiceBase: + session_event: str = 'session_event' + + """Basic event bus, to have an empty stand-in when not needed""" + def dispatch(self, event_name: str, payload: Any) -> None: + pass + + def __emit_session_event(self, + event_name: str, + payload: Dict) -> None: + self.dispatch( + event_name = EventServiceBase.session_event, + payload = dict( + event = event_name, + data = payload + ) + ) + + # Define events here for every event in the system. + # This will make them easier to integrate until we find a schema generator. + def emit_generator_progress(self, + graph_execution_state_id: str, + invocation_id: str, + step: int, + percent: float + ) -> None: + """Emitted when there is generation progress""" + self.__emit_session_event( + event_name = 'generator_progress', + payload = dict( + graph_execution_state_id = graph_execution_state_id, + invocation_id = invocation_id, + step = step, + percent = percent + ) + ) + + def emit_invocation_complete(self, + graph_execution_state_id: str, + invocation_id: str, + result: Dict + ) -> None: + """Emitted when an invocation has completed""" + self.__emit_session_event( + event_name = 'invocation_complete', + payload = dict( + graph_execution_state_id = graph_execution_state_id, + invocation_id = invocation_id, + result = result + ) + ) + + def emit_invocation_started(self, + graph_execution_state_id: str, + invocation_id: str + ) -> None: + """Emitted when an invocation has started""" + self.__emit_session_event( + event_name = 'invocation_started', + payload = dict( + graph_execution_state_id = graph_execution_state_id, + invocation_id = invocation_id + ) + ) + + def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None: + """Emitted when a session has completed all invocations""" + self.__emit_session_event( + event_name = 'graph_execution_state_complete', + payload = dict( + graph_execution_state_id = graph_execution_state_id + ) + ) diff --git a/ldm/invoke/app/services/generate_initializer.py b/ldm/invoke/app/services/generate_initializer.py new file mode 100644 index 0000000000..39c0fe491e --- /dev/null +++ b/ldm/invoke/app/services/generate_initializer.py @@ -0,0 +1,233 @@ +from argparse import Namespace +import os +import sys +import traceback + +from ...model_manager import ModelManager + +from ...globals import Globals +from ....generate import Generate +import ldm.invoke + + +# TODO: most of this code should be split into individual services as the Generate.py code is deprecated +def get_generate(args, config) -> Generate: + if not args.conf: + config_file = os.path.join(Globals.root,'configs','models.yaml') + if not os.path.exists(config_file): + report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found.")) + + print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}') + print(f'>> InvokeAI runtime directory is "{Globals.root}"') + + # these two lines prevent a horrible warning message from appearing + # when the frozen CLIP tokenizer is imported + import transformers # type: ignore + transformers.logging.set_verbosity_error() + import diffusers + diffusers.logging.set_verbosity_error() + + # Loading Face Restoration and ESRGAN Modules + gfpgan,codeformer,esrgan = load_face_restoration(args) + + # normalize the config directory relative to root + if not os.path.isabs(args.conf): + args.conf = os.path.normpath(os.path.join(Globals.root,args.conf)) + + if args.embeddings: + if not os.path.isabs(args.embedding_path): + embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path)) + else: + embedding_path = args.embedding_path + else: + embedding_path = None + + # migrate legacy models + ModelManager.migrate_models() + + # load the infile as a list of lines + if args.infile: + try: + if os.path.isfile(args.infile): + infile = open(args.infile, 'r', encoding='utf-8') + elif args.infile == '-': # stdin + infile = sys.stdin + else: + raise FileNotFoundError(f'{args.infile} not found.') + except (FileNotFoundError, IOError) as e: + print(f'{e}. Aborting.') + sys.exit(-1) + + # creating a Generate object: + try: + gen = Generate( + conf = args.conf, + model = args.model, + sampler_name = args.sampler_name, + embedding_path = embedding_path, + full_precision = args.full_precision, + precision = args.precision, + gfpgan = gfpgan, + codeformer = codeformer, + esrgan = esrgan, + free_gpu_mem = args.free_gpu_mem, + safety_checker = args.safety_checker, + max_loaded_models = args.max_loaded_models, + ) + except (FileNotFoundError, TypeError, AssertionError) as e: + report_model_error(opt,e) + except (IOError, KeyError) as e: + print(f'{e}. Aborting.') + sys.exit(-1) + + if args.seamless: + print(">> changed to seamless tiling mode") + + # preload the model + try: + gen.load_model() + except KeyError: + pass + except Exception as e: + report_model_error(args, e) + + # try to autoconvert new models + # autoimport new .ckpt files + if path := args.autoconvert: + gen.model_manager.autoconvert_weights( + conf_path=args.conf, + weights_directory=path, + ) + + return gen + + +def load_face_restoration(opt): + try: + gfpgan, codeformer, esrgan = None, None, None + if opt.restore or opt.esrgan: + from ldm.invoke.restoration import Restoration + restoration = Restoration() + if opt.restore: + gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path) + else: + print('>> Face restoration disabled') + if opt.esrgan: + esrgan = restoration.load_esrgan(opt.esrgan_bg_tile) + else: + print('>> Upscaling disabled') + else: + print('>> Face restoration and upscaling disabled') + except (ModuleNotFoundError, ImportError): + print(traceback.format_exc(), file=sys.stderr) + print('>> You may need to install the ESRGAN and/or GFPGAN modules') + return gfpgan,codeformer,esrgan + + +def report_model_error(opt:Namespace, e:Exception): + print(f'** An error occurred while attempting to initialize the model: "{str(e)}"') + print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.') + yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE') + if yes_to_all: + print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE') + else: + response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ') + if response.startswith(('n', 'N')): + return + + print('invokeai-configure is launching....\n') + + # Match arguments that were set on the CLI + # only the arguments accepted by the configuration script are parsed + root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] + config = ["--config", opt.conf] if opt.conf is not None else [] + previous_args = sys.argv + sys.argv = [ 'invokeai-configure' ] + sys.argv.extend(root_dir) + sys.argv.extend(config) + if yes_to_all is not None: + for arg in yes_to_all.split(): + sys.argv.append(arg) + + from ldm.invoke.config import invokeai_configure + invokeai_configure.main() + # TODO: Figure out how to restart + # print('** InvokeAI will now restart') + # sys.argv = previous_args + # main() # would rather do a os.exec(), but doesn't exist? + # sys.exit(0) + + +# Temporary initializer for Generate until we migrate off of it +def old_get_generate(args, config) -> Generate: + # TODO: Remove the need for globals + from ldm.invoke.globals import Globals + + # alert - setting globals here + Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.')) + Globals.try_patchmatch = args.patchmatch + + print(f'>> InvokeAI runtime directory is "{Globals.root}"') + + # these two lines prevent a horrible warning message from appearing + # when the frozen CLIP tokenizer is imported + import transformers + transformers.logging.set_verbosity_error() + + # Loading Face Restoration and ESRGAN Modules + gfpgan, codeformer, esrgan = None, None, None + try: + if config.restore or config.esrgan: + from ldm.invoke.restoration import Restoration + restoration = Restoration() + if config.restore: + gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path) + else: + print('>> Face restoration disabled') + if config.esrgan: + esrgan = restoration.load_esrgan(config.esrgan_bg_tile) + else: + print('>> Upscaling disabled') + else: + print('>> Face restoration and upscaling disabled') + except (ModuleNotFoundError, ImportError): + print(traceback.format_exc(), file=sys.stderr) + print('>> You may need to install the ESRGAN and/or GFPGAN modules') + + # normalize the config directory relative to root + if not os.path.isabs(config.conf): + config.conf = os.path.normpath(os.path.join(Globals.root,config.conf)) + + if config.embeddings: + if not os.path.isabs(config.embedding_path): + embedding_path = os.path.normpath(os.path.join(Globals.root,config.embedding_path)) + else: + embedding_path = None + + + # TODO: lazy-initialize this by wrapping it + try: + generate = Generate( + conf = config.conf, + model = config.model, + sampler_name = config.sampler_name, + embedding_path = embedding_path, + full_precision = config.full_precision, + precision = config.precision, + gfpgan = gfpgan, + codeformer = codeformer, + esrgan = esrgan, + free_gpu_mem = config.free_gpu_mem, + safety_checker = config.safety_checker, + max_loaded_models = config.max_loaded_models, + ) + except (FileNotFoundError, TypeError, AssertionError): + #emergency_model_reconfigure() # TODO? + sys.exit(-1) + except (IOError, KeyError) as e: + print(f'{e}. Aborting.') + sys.exit(-1) + + generate.free_gpu_mem = config.free_gpu_mem + + return generate diff --git a/ldm/invoke/app/services/graph.py b/ldm/invoke/app/services/graph.py new file mode 100644 index 0000000000..8d1583fc8b --- /dev/null +++ b/ldm/invoke/app/services/graph.py @@ -0,0 +1,797 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import copy +import itertools +from types import NoneType +import uuid +import networkx as nx +from pydantic import BaseModel, validator +from pydantic.fields import Field +from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints, Annotated + +from .invocation_services import InvocationServices +from ..invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ..invocations import * + + +class EdgeConnection(BaseModel): + node_id: str = Field(description="The id of the node for this edge connection") + field: str = Field(description="The field for this connection") + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + getattr(other, 'node_id', None) == self.node_id and + getattr(other, 'field', None) == self.field) + + def __hash__(self): + return hash(f'{self.node_id}.{self.field}') + + +def get_output_field(node: BaseInvocation, field: str) -> Any: + node_type = type(node) + node_outputs = get_type_hints(node_type.get_output_type()) + node_output_field = node_outputs.get(field) or None + return node_output_field + + +def get_input_field(node: BaseInvocation, field: str) -> Any: + node_type = type(node) + node_inputs = get_type_hints(node_type) + node_input_field = node_inputs.get(field) or None + return node_input_field + + +def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: + if not from_type: + return False + if not to_type: + return False + + # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such) + if from_type and to_type: + # Ports are compatible + if (from_type == to_type or + from_type == Any or + to_type == Any or + Any in get_args(from_type) or + Any in get_args(to_type)): + return True + + if from_type in get_args(to_type): + return True + + if to_type in get_args(from_type): + return True + + if not issubclass(from_type, to_type): + return False + else: + return False + + return True + + +def are_connections_compatible( + from_node: BaseInvocation, + from_field: str, + to_node: BaseInvocation, + to_field: str) -> bool: + """Determines if a connection between fields of two nodes is compatible.""" + + # TODO: handle iterators and collectors + from_node_field = get_output_field(from_node, from_field) + to_node_field = get_input_field(to_node, to_field) + + return are_connection_types_compatible(from_node_field, to_node_field) + + +class NodeAlreadyInGraphError(Exception): + pass + + +class InvalidEdgeError(Exception): + pass + +class NodeNotFoundError(Exception): + pass + +class NodeAlreadyExecutedError(Exception): + pass + + +# TODO: Create and use an Empty output? +class GraphInvocationOutput(BaseInvocationOutput): + type: Literal['graph_output'] = 'graph_output' + + +# TODO: Fill this out and move to invocations +class GraphInvocation(BaseInvocation): + type: Literal['graph'] = 'graph' + + # TODO: figure out how to create a default here + graph: 'Graph' = Field(description="The graph to run", default=None) + + def invoke(self, context: InvocationContext) -> GraphInvocationOutput: + """Invoke with provided services and return outputs.""" + return GraphInvocationOutput() + + +class IterateInvocationOutput(BaseInvocationOutput): + """Used to connect iteration outputs. Will be expanded to a specific output.""" + type: Literal['iterate_output'] = 'iterate_output' + + item: Any = Field(description="The item being iterated over") + + +# TODO: Fill this out and move to invocations +class IterateInvocation(BaseInvocation): + type: Literal['iterate'] = 'iterate' + + collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list) + index: int = Field(description="The index, will be provided on executed iterators", default=0) + + def invoke(self, context: InvocationContext) -> IterateInvocationOutput: + """Produces the outputs as values""" + return IterateInvocationOutput(item = self.collection[self.index]) + + +class CollectInvocationOutput(BaseInvocationOutput): + type: Literal['collect_output'] = 'collect_output' + + collection: list[Any] = Field(description="The collection of input items") + + +class CollectInvocation(BaseInvocation): + """Collects values into a collection""" + type: Literal['collect'] = 'collect' + + item: Any = Field(description="The item to collect (all inputs must be of the same type)", default=None) + collection: list[Any] = Field(description="The collection, will be provided on execution", default_factory=list) + + def invoke(self, context: InvocationContext) -> CollectInvocationOutput: + """Invoke with provided services and return outputs.""" + return CollectInvocationOutput(collection = copy.copy(self.collection)) + + +InvocationsUnion = Union[BaseInvocation.get_invocations()] +InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] + + +class Graph(BaseModel): + id: str = Field(description="The id of this graph", default_factory=uuid.uuid4) + # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me + nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(description="The nodes in this graph", default_factory=dict) + edges: list[tuple[EdgeConnection,EdgeConnection]] = Field(description="The connections between nodes and their fields in this graph", default_factory=list) + + def add_node(self, node: BaseInvocation) -> None: + """Adds a node to a graph + + :raises NodeAlreadyInGraphError: the node is already present in the graph. + """ + + if node.id in self.nodes: + raise NodeAlreadyInGraphError() + + self.nodes[node.id] = node + + + def _get_graph_and_node(self, node_path: str) -> tuple['Graph', str]: + """Returns the graph and node id for a node path.""" + # Materialized graphs may have nodes at the top level + if node_path in self.nodes: + return (self, node_path) + + node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')] + if node_id not in self.nodes: + raise NodeNotFoundError(f'Node {node_path} not found in graph') + + node = self.nodes[node_id] + + if not isinstance(node, GraphInvocation): + # There's more node path left but this isn't a graph - failure + raise NodeNotFoundError('Node path terminated early at a non-graph node') + + return node.graph._get_graph_and_node(node_path[node_path.index('.')+1:]) + + + def delete_node(self, node_path: str) -> None: + """Deletes a node from a graph""" + + try: + graph, node_id = self._get_graph_and_node(node_path) + + # Delete edges for this node + input_edges = self._get_input_edges_and_graphs(node_path) + output_edges = self._get_output_edges_and_graphs(node_path) + + for edge_graph,_,edge in input_edges: + edge_graph.delete_edge(edge) + + for edge_graph,_,edge in output_edges: + edge_graph.delete_edge(edge) + + del graph.nodes[node_id] + + except NodeNotFoundError: + pass # Ignore, not doesn't exist (should this throw?) + + + def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: + """Adds an edge to a graph + + :raises InvalidEdgeError: the provided edge is invalid. + """ + + if self._is_edge_valid(edge) and edge not in self.edges: + self.edges.append(edge) + else: + raise InvalidEdgeError() + + + def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: + """Deletes an edge from a graph""" + + try: + self.edges.remove(edge) + except KeyError: + pass + + + def is_valid(self) -> bool: + """Validates the graph.""" + + # Validate all subgraphs + for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)): + if not gn.graph.is_valid(): + return False + + # Validate all edges reference nodes in the graph + node_ids = set([e[0].node_id for e in self.edges]+[e[1].node_id for e in self.edges]) + if not all((self.has_node(node_id) for node_id in node_ids)): + return False + + # Validate there are no cycles + g = self.nx_graph_flat() + if not nx.is_directed_acyclic_graph(g): + return False + + # Validate all edge connections are valid + if not all((are_connections_compatible( + self.get_node(e[0].node_id), e[0].field, + self.get_node(e[1].node_id), e[1].field + ) for e in self.edges)): + return False + + # Validate all iterators + # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available + if not all((self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))): + return False + + # Validate all collectors + # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available + if not all((self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))): + return False + + return True + + def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool: + """Validates that a new edge doesn't create a cycle in the graph""" + + # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) + try: + from_node = self.get_node(edge[0].node_id) + to_node = self.get_node(edge[1].node_id) + except NodeNotFoundError: + return False + + # Validate that an edge to this node+field doesn't already exist + input_edges = self._get_input_edges(edge[1].node_id, edge[1].field) + if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): + return False + + # Validate that no cycles would be created + g = self.nx_graph_flat() + g.add_edge(edge[0].node_id, edge[1].node_id) + if not nx.is_directed_acyclic_graph(g): + return False + + # Validate that the field types are compatible + if not are_connections_compatible(from_node, edge[0].field, to_node, edge[1].field): + return False + + # Validate if iterator output type matches iterator input type (if this edge results in both being set) + if isinstance(to_node, IterateInvocation) and edge[1].field == 'collection': + if not self._is_iterator_connection_valid(edge[1].node_id, new_input = edge[0]): + return False + + # Validate if iterator input type matches output type (if this edge results in both being set) + if isinstance(from_node, IterateInvocation) and edge[0].field == 'item': + if not self._is_iterator_connection_valid(edge[0].node_id, new_output = edge[1]): + return False + + # Validate if collector input type matches output type (if this edge results in both being set) + if isinstance(to_node, CollectInvocation) and edge[1].field == 'item': + if not self._is_collector_connection_valid(edge[1].node_id, new_input = edge[0]): + return False + + # Validate if collector output type matches input type (if this edge results in both being set) + if isinstance(from_node, CollectInvocation) and edge[0].field == 'collection': + if not self._is_collector_connection_valid(edge[0].node_id, new_output = edge[1]): + return False + + return True + + def has_node(self, node_path: str) -> bool: + """Determines whether or not a node exists in the graph.""" + try: + n = self.get_node(node_path) + if n is not None: + return True + else: + return False + except NodeNotFoundError: + return False + + def get_node(self, node_path: str) -> InvocationsUnion: + """Gets a node from the graph using a node path.""" + # Materialized graphs may have nodes at the top level + graph, node_id = self._get_graph_and_node(node_path) + return graph.nodes[node_id] + + + def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str: + return node_id if prefix is None or prefix == '' else f'{prefix}.{node_id}' + + + def update_node(self, node_path: str, new_node: BaseInvocation) -> None: + """Updates a node in the graph.""" + graph, node_id = self._get_graph_and_node(node_path) + node = graph.nodes[node_id] + + # Ensure the node type matches the new node + if type(node) != type(new_node): + raise TypeError(f'Node {node_path} is type {type(node)} but new node is type {type(new_node)}') + + # Ensure the new id is either the same or is not in the graph + prefix = None if '.' not in node_path else node_path[:node_path.rindex('.')] + new_path = self._get_node_path(new_node.id, prefix = prefix) + if new_node.id != node.id and self.has_node(new_path): + raise NodeAlreadyInGraphError('Node with id {new_node.id} already exists in graph') + + # Set the new node in the graph + graph.nodes[new_node.id] = new_node + if new_node.id != node.id: + input_edges = self._get_input_edges_and_graphs(node_path) + output_edges = self._get_output_edges_and_graphs(node_path) + + # Delete node and all edges + graph.delete_node(node_path) + + # Create new edges for each input and output + for graph,_,edge in input_edges: + # Remove the graph prefix from the node path + new_graph_node_path = new_node.id if '.' not in edge[1].node_id else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}' + graph.add_edge((edge[0], EdgeConnection(node_id = new_graph_node_path, field = edge[1].field))) + + for graph,_,edge in output_edges: + # Remove the graph prefix from the node path + new_graph_node_path = new_node.id if '.' not in edge[0].node_id else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}' + graph.add_edge((EdgeConnection(node_id = new_graph_node_path, field = edge[0].field), edge[1])) + + + def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[tuple[EdgeConnection,EdgeConnection]]: + """Gets all input edges for a node""" + edges = self._get_input_edges_and_graphs(node_path) + + # Filter to edges that match the field + filtered_edges = (e for e in edges if field is None or e[2][1].field == field) + + # Create full node paths for each edge + return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges] + + + def _get_input_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]: + """Gets all input edges for a node along with the graph they are in and the graph's path""" + edges = list() + + # Return any input edges that appear in this graph + edges.extend([(self, prefix, e) for e in self.edges if e[1].node_id == node_path]) + + node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')] + node = self.nodes[node_id] + + if isinstance(node, GraphInvocation): + graph = node.graph + graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix) + graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path) + edges.extend(graph_edges) + + return edges + + + def _get_output_edges(self, node_path: str, field: str) -> list[tuple[EdgeConnection,EdgeConnection]]: + """Gets all output edges for a node""" + edges = self._get_output_edges_and_graphs(node_path) + + # Filter to edges that match the field + filtered_edges = (e for e in edges if e[2][0].field == field) + + # Create full node paths for each edge + return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges] + + + def _get_output_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]: + """Gets all output edges for a node along with the graph they are in and the graph's path""" + edges = list() + + # Return any input edges that appear in this graph + edges.extend([(self, prefix, e) for e in self.edges if e[0].node_id == node_path]) + + node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')] + node = self.nodes[node_id] + + if isinstance(node, GraphInvocation): + graph = node.graph + graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix) + graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path) + edges.extend(graph_edges) + + return edges + + + def _is_iterator_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool: + inputs = list([e[0] for e in self._get_input_edges(node_path, 'collection')]) + outputs = list([e[1] for e in self._get_output_edges(node_path, 'item')]) + + if new_input is not None: + inputs.append(new_input) + if new_output is not None: + outputs.append(new_output) + + # Only one input is allowed for iterators + if len(inputs) > 1: + return False + + # Get input and output fields (the fields linked to the iterator's input/output) + input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field) + output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) + + # Input type must be a list + if get_origin(input_field) != list: + return False + + # Validate that all outputs match the input type + input_field_item_type = get_args(input_field)[0] + if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)): + return False + + return True + + def _is_collector_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool: + inputs = list([e[0] for e in self._get_input_edges(node_path, 'item')]) + outputs = list([e[1] for e in self._get_output_edges(node_path, 'collection')]) + + if new_input is not None: + inputs.append(new_input) + if new_output is not None: + outputs.append(new_output) + + # Get input and output fields (the fields linked to the iterator's input/output) + input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs]) + output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) + + # Validate that all inputs are derived from or match a single type + input_field_types = set([t for input_field in input_fields for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) if t != NoneType]) # Get unique types + type_tree = nx.DiGraph() + type_tree.add_nodes_from(input_field_types) + type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) + type_degrees = type_tree.in_degree(type_tree.nodes) + if sum((t[1] == 0 for t in type_degrees)) != 1: + return False # There is more than one root type + + # Get the input root type + input_root_type = next(t[0] for t in type_degrees if t[1] == 0) + + # Verify that all outputs are lists + if not all((get_origin(f) == list for f in output_fields)): + return False + + # Verify that all outputs match the input type (are a base class or the same class) + if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)): + return False + + return True + + def nx_graph(self) -> nx.DiGraph: + """Returns a NetworkX DiGraph representing the layout of this graph""" + # TODO: Cache this? + g = nx.DiGraph() + g.add_nodes_from([n for n in self.nodes.keys()]) + g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges])) + return g + + def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: + """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" + g = nx_graph or nx.DiGraph() + + # Add all nodes from this graph except graph/iteration nodes + g.add_nodes_from([self._get_node_path(n.id, prefix) for n in self.nodes.values() if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)]) + + # Expand graph nodes + for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)): + sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) + + # TODO: figure out if iteration nodes need to be expanded + + unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges]) + g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) + return g + + +class GraphExecutionState(BaseModel): + """Tracks the state of a graph execution""" + id: str = Field(description="The id of the execution state", default_factory=uuid.uuid4) + + # TODO: Store a reference to the graph instead of the actual graph? + graph: Graph = Field(description="The graph being executed") + + # The graph of materialized nodes + execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes", default_factory=Graph) + + # Nodes that have been executed + executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set) + executed_history: list[str] = Field(description="The list of node ids that have been executed, in order of execution", default_factory=list) + + # The results of executed nodes + results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict) + + # Map of prepared/executed nodes to their original nodes + prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict) + + # Map of original nodes to prepared nodes + source_prepared_mapping: dict[str, set[str]] = Field(description="The map of original graph nodes to prepared nodes", default_factory=dict) + + def next(self) -> BaseInvocation | None: + """Gets the next node ready to execute.""" + + # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes + # possibly with a timeout? + + # If there are no prepared nodes, prepare some nodes + next_node = self._get_next_node() + if next_node is None: + prepared_id = self._prepare() + + # TODO: prepare multiple nodes at once? + # while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation): + # prepared_id = self._prepare() + + if prepared_id is not None: + next_node = self._get_next_node() + + # Get values from edges + if next_node is not None: + self._prepare_inputs(next_node) + + # If next is still none, there's no next node, return None + return next_node + + def complete(self, node_id: str, output: InvocationOutputsUnion): + """Marks a node as complete""" + + if node_id not in self.execution_graph.nodes: + return # TODO: log error? + + # Mark node as executed + self.executed.add(node_id) + self.results[node_id] = output + + # Check if source node is complete (all prepared nodes are complete) + source_node = self.prepared_source_mapping[node_id] + prepared_nodes = self.source_prepared_mapping[source_node] + + if all([n in self.executed for n in prepared_nodes]): + self.executed.add(source_node) + self.executed_history.append(source_node) + + def is_complete(self) -> bool: + """Returns true if the graph is complete""" + return all((k in self.executed for k in self.graph.nodes)) + + def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: + """Prepares an iteration node and connects all edges, returning the new node id""" + + node = self.graph.get_node(node_path) + + self_iteration_count = -1 + + # If this is an iterator node, we must create a copy for each iteration + if isinstance(node, IterateInvocation): + # Get input collection edge (should error if there are no inputs) + input_collection_edge = next(iter(self.graph._get_input_edges(node_path, 'collection'))) + input_collection_prepared_node_id = next(n[1] for n in iteration_node_map if n[0] == input_collection_edge[0].node_id) + input_collection_prepared_node_output = self.results[input_collection_prepared_node_id] + input_collection = getattr(input_collection_prepared_node_output, input_collection_edge[0].field) + self_iteration_count = len(input_collection) + + new_nodes = list() + if self_iteration_count == 0: + # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid. + return new_nodes + + # Get all input edges + input_edges = self.graph._get_input_edges(node_path) + + # Create new edges for this iteration + # For collect nodes, this may contain multiple inputs to the same field + new_edges = list() + for edge in input_edges: + for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge[0].node_id): + new_edge = (EdgeConnection(node_id = input_node_id, field = edge[0].field), EdgeConnection(node_id = '', field = edge[1].field)) + new_edges.append(new_edge) + + # Create a new node (or one for each iteration of this iterator) + for i in (range(self_iteration_count) if self_iteration_count > 0 else [-1]): + # Create a new node + new_node = copy.deepcopy(node) + + # Create the node id (use a random uuid) + new_node.id = str(uuid.uuid4()) + + # Set the iteration index for iteration invocations + if isinstance(new_node, IterateInvocation): + new_node.index = i + + # Add to execution graph + self.execution_graph.add_node(new_node) + self.prepared_source_mapping[new_node.id] = node_path + if node_path not in self.source_prepared_mapping: + self.source_prepared_mapping[node_path] = set() + self.source_prepared_mapping[node_path].add(new_node.id) + + # Add new edges to execution graph + for edge in new_edges: + new_edge = (edge[0], EdgeConnection(node_id = new_node.id, field = edge[1].field)) + self.execution_graph.add_edge(new_edge) + + new_nodes.append(new_node.id) + + return new_nodes + + def _iterator_graph(self) -> nx.DiGraph: + """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" + g = self.graph.nx_graph() + collectors = (n for n in self.graph.nodes if isinstance(self.graph.nodes[n], CollectInvocation)) + for c in collectors: + g.remove_edges_from(list(g.in_edges(c))) + return g + + + def _get_node_iterators(self, node_id: str) -> list[str]: + """Gets iterators for a node""" + g = self._iterator_graph() + iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.nodes[n], IterateInvocation)] + return iterators + + + def _prepare(self) -> Optional[str]: + # Get flattened source graph + g = self.graph.nx_graph_flat() + + # Find next unprepared node where all source nodes are executed + sorted_nodes = nx.topological_sort(g) + next_node_id = next((n for n in sorted_nodes if n not in self.source_prepared_mapping and all((e[0] in self.executed for e in g.in_edges(n)))), None) + + if next_node_id == None: + return None + + # Get all parents of the next node + next_node_parents = [e[0] for e in g.in_edges(next_node_id)] + + # Create execution nodes + next_node = self.graph.get_node(next_node_id) + new_node_ids = list() + if isinstance(next_node, CollectInvocation): + # Collapse all iterator input mappings and create a single execution node for the collect invocation + all_iteration_mappings = list(itertools.chain(*(((s,p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))) + #all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings))) + create_results = self._create_execution_node(next_node_id, all_iteration_mappings) + if create_results is not None: + new_node_ids.extend(create_results) + else: # Iterators or normal nodes + # Get all iterator combinations for this node + # Will produce a list of lists of prepared iterator nodes, from which results can be iterated + iterator_nodes = self._get_node_iterators(next_node_id) + iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes] + iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared)) + + # Select the correct prepared parents for each iteration + # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator + # TODO: Handle a node mapping to none + eg = self.execution_graph.nx_graph_flat() + prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] + + # Create execution node for each iteration + for iteration_mappings in prepared_parent_mappings: + create_results = self._create_execution_node(next_node_id, iteration_mappings) + if create_results is not None: + new_node_ids.extend(create_results) + + return next(iter(new_node_ids), None) + + def _get_iteration_node(self, source_node_path: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str]) -> Optional[str]: + """Gets the prepared version of the specified source node that matches every iteration specified""" + prepared_nodes = self.source_prepared_mapping[source_node_path] + if len(prepared_nodes) == 1: + return next(iter(prepared_nodes)) + + # Check if the requested node is an iterator + prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) + if prepared_iterator is not None: + return prepared_iterator + + # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) + iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] + parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)] + + return next((n for n in prepared_nodes if all(pit for pit in parent_iterators if nx.has_path(execution_graph, pit[0], n))), None) + + def _get_next_node(self) -> Optional[BaseInvocation]: + g = self.execution_graph.nx_graph() + sorted_nodes = nx.topological_sort(g) + next_node = next((n for n in sorted_nodes if n not in self.executed), None) + if next_node is None: + return None + + return self.execution_graph.nodes[next_node] + + def _prepare_inputs(self, node: BaseInvocation): + input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id] + if isinstance(node, CollectInvocation): + output_collection = [getattr(self.results[edge[0].node_id], edge[0].field) for edge in input_edges if edge[1].field == 'item'] + setattr(node, 'collection', output_collection) + else: + for edge in input_edges: + output_value = getattr(self.results[edge[0].node_id], edge[0].field) + setattr(node, edge[1].field, output_value) + + # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state + def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool: + if not self._is_edge_valid(edge): + return False + + # Invalid if destination has already been prepared or executed + if edge[1].node_id in self.source_prepared_mapping: + return False + + # Otherwise, the edge is valid + return True + + def _is_node_updatable(self, node_id: str) -> bool: + # The node is updatable as long as it hasn't been prepared or executed + return node_id not in self.source_prepared_mapping + + def add_node(self, node: BaseInvocation) -> None: + self.graph.add_node(node) + + def update_node(self, node_path: str, new_node: BaseInvocation) -> None: + if not self._is_node_updatable(node_path): + raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be updated') + self.graph.update_node(node_path, new_node) + + def delete_node(self, node_path: str) -> None: + if not self._is_node_updatable(node_path): + raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be deleted') + self.graph.delete_node(node_path) + + def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: + if not self._is_node_updatable(edge[1].node_id): + raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to') + self.graph.add_edge(edge) + + def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: + if not self._is_node_updatable(edge[1].node_id): + raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted') + self.graph.delete_edge(edge) + +GraphInvocation.update_forward_refs() diff --git a/ldm/invoke/app/services/image_storage.py b/ldm/invoke/app/services/image_storage.py new file mode 100644 index 0000000000..03227d870b --- /dev/null +++ b/ldm/invoke/app/services/image_storage.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC, abstractmethod +from enum import Enum +import datetime +import os +from pathlib import Path +from queue import Queue +from typing import Dict +from PIL.Image import Image +from ...pngwriter import PngWriter + + +class ImageType(str, Enum): + RESULT = 'results' + INTERMEDIATE = 'intermediates' + UPLOAD = 'uploads' + + +class ImageStorageBase(ABC): + """Responsible for storing and retrieving images.""" + + @abstractmethod + def get(self, image_type: ImageType, image_name: str) -> Image: + pass + + # TODO: make this a bit more flexible for e.g. cloud storage + @abstractmethod + def get_path(self, image_type: ImageType, image_name: str) -> str: + pass + + @abstractmethod + def save(self, image_type: ImageType, image_name: str, image: Image) -> None: + pass + + @abstractmethod + def delete(self, image_type: ImageType, image_name: str) -> None: + pass + + def create_name(self, context_id: str, node_id: str) -> str: + return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png' + + +class DiskImageStorage(ImageStorageBase): + """Stores images on disk""" + __output_folder: str + __pngWriter: PngWriter + __cache_ids: Queue # TODO: this is an incredibly naive cache + __cache: Dict[str, Image] + __max_cache_size: int + + def __init__(self, output_folder: str): + self.__output_folder = output_folder + self.__pngWriter = PngWriter(output_folder) + self.__cache = dict() + self.__cache_ids = Queue() + self.__max_cache_size = 10 # TODO: get this from config + + Path(output_folder).mkdir(parents=True, exist_ok=True) + + # TODO: don't hard-code. get/save/delete should maybe take subpath? + for image_type in ImageType: + Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True) + + def get(self, image_type: ImageType, image_name: str) -> Image: + image_path = self.get_path(image_type, image_name) + cache_item = self.__get_cache(image_path) + if cache_item: + return cache_item + + image = Image.open(image_path) + self.__set_cache(image_path, image) + return image + + # TODO: make this a bit more flexible for e.g. cloud storage + def get_path(self, image_type: ImageType, image_name: str) -> str: + path = os.path.join(self.__output_folder, image_type, image_name) + return path + + def save(self, image_type: ImageType, image_name: str, image: Image) -> None: + image_subpath = os.path.join(image_type, image_name) + self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer + + image_path = self.get_path(image_type, image_name) + self.__set_cache(image_path, image) + + def delete(self, image_type: ImageType, image_name: str) -> None: + image_path = self.get_path(image_type, image_name) + if os.path.exists(image_path): + os.remove(image_path) + + if image_path in self.__cache: + del self.__cache[image_path] + + def __get_cache(self, image_name: str) -> Image: + return None if image_name not in self.__cache else self.__cache[image_name] + + def __set_cache(self, image_name: str, image: Image): + if not image_name in self.__cache: + self.__cache[image_name] = image + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache + if len(self.__cache) > self.__max_cache_size: + cache_id = self.__cache_ids.get() + del self.__cache[cache_id] diff --git a/ldm/invoke/app/services/invocation_queue.py b/ldm/invoke/app/services/invocation_queue.py new file mode 100644 index 0000000000..0a5b5ae3bb --- /dev/null +++ b/ldm/invoke/app/services/invocation_queue.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC, abstractmethod +from queue import Queue + + +# TODO: make this serializable +class InvocationQueueItem: + #session_id: str + graph_execution_state_id: str + invocation_id: str + invoke_all: bool + + def __init__(self, + #session_id: str, + graph_execution_state_id: str, + invocation_id: str, + invoke_all: bool = False): + #self.session_id = session_id + self.graph_execution_state_id = graph_execution_state_id + self.invocation_id = invocation_id + self.invoke_all = invoke_all + + +class InvocationQueueABC(ABC): + """Abstract base class for all invocation queues""" + @abstractmethod + def get(self) -> InvocationQueueItem: + pass + + @abstractmethod + def put(self, item: InvocationQueueItem|None) -> None: + pass + + +class MemoryInvocationQueue(InvocationQueueABC): + __queue: Queue + + def __init__(self): + self.__queue = Queue() + + def get(self) -> InvocationQueueItem: + return self.__queue.get() + + def put(self, item: InvocationQueueItem|None) -> None: + self.__queue.put(item) diff --git a/ldm/invoke/app/services/invocation_services.py b/ldm/invoke/app/services/invocation_services.py new file mode 100644 index 0000000000..9eb5309d3d --- /dev/null +++ b/ldm/invoke/app/services/invocation_services.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from .image_storage import ImageStorageBase +from .events import EventServiceBase +from ....generate import Generate + + +class InvocationServices(): + """Services that can be used by invocations""" + generate: Generate # TODO: wrap Generate, or split it up from model? + events: EventServiceBase + images: ImageStorageBase + + def __init__(self, + generate: Generate, + events: EventServiceBase, + images: ImageStorageBase + ): + self.generate = generate + self.events = events + self.images = images diff --git a/ldm/invoke/app/services/invoker.py b/ldm/invoke/app/services/invoker.py new file mode 100644 index 0000000000..796f541781 --- /dev/null +++ b/ldm/invoke/app/services/invoker.py @@ -0,0 +1,109 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +from abc import ABC +from threading import Event, Thread +from .graph import Graph, GraphExecutionState +from .item_storage import ItemStorageABC +from ..invocations.baseinvocation import InvocationContext +from .invocation_services import InvocationServices +from .invocation_queue import InvocationQueueABC, InvocationQueueItem + + +class InvokerServices: + """Services used by the Invoker for execution""" + + queue: InvocationQueueABC + graph_execution_manager: ItemStorageABC[GraphExecutionState] + processor: 'InvocationProcessorABC' + + def __init__(self, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC[GraphExecutionState], + processor: 'InvocationProcessorABC'): + self.queue = queue + self.graph_execution_manager = graph_execution_manager + self.processor = processor + + +class Invoker: + """The invoker, used to execute invocations""" + + services: InvocationServices + invoker_services: InvokerServices + + def __init__(self, + services: InvocationServices, # Services used by nodes to perform invocations + invoker_services: InvokerServices # Services used by the invoker for orchestration + ): + self.services = services + self.invoker_services = invoker_services + self._start() + + + def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None: + """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute""" + + # Get the next invocation + invocation = graph_execution_state.next() + if not invocation: + return None + + # Save the execution state + self.invoker_services.graph_execution_manager.set(graph_execution_state) + + # Queue the invocation + print(f'queueing item {invocation.id}') + self.invoker_services.queue.put(InvocationQueueItem( + #session_id = session.id, + graph_execution_state_id = graph_execution_state.id, + invocation_id = invocation.id, + invoke_all = invoke_all + )) + + return invocation.id + + + def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState: + """Creates a new execution state for the given graph""" + new_state = GraphExecutionState(graph = Graph() if graph is None else graph) + self.invoker_services.graph_execution_manager.set(new_state) + return new_state + + + def __start_service(self, service) -> None: + # Call start() method on any services that have it + start_op = getattr(service, 'start', None) + if callable(start_op): + start_op(self) + + + def __stop_service(self, service) -> None: + # Call stop() method on any services that have it + stop_op = getattr(service, 'stop', None) + if callable(stop_op): + stop_op(self) + + + def _start(self) -> None: + """Starts the invoker. This is called automatically when the invoker is created.""" + for service in vars(self.invoker_services): + self.__start_service(getattr(self.invoker_services, service)) + + for service in vars(self.services): + self.__start_service(getattr(self.services, service)) + + + def stop(self) -> None: + """Stops the invoker. A new invoker will have to be created to execute further.""" + # First stop all services + for service in vars(self.services): + self.__stop_service(getattr(self.services, service)) + + for service in vars(self.invoker_services): + self.__stop_service(getattr(self.invoker_services, service)) + + self.invoker_services.queue.put(None) + + +class InvocationProcessorABC(ABC): + pass \ No newline at end of file diff --git a/ldm/invoke/app/services/item_storage.py b/ldm/invoke/app/services/item_storage.py new file mode 100644 index 0000000000..738f06cb7e --- /dev/null +++ b/ldm/invoke/app/services/item_storage.py @@ -0,0 +1,57 @@ + +from typing import Callable, TypeVar, Generic +from pydantic import BaseModel, Field +from pydantic.generics import GenericModel +from abc import ABC, abstractmethod + +T = TypeVar('T', bound=BaseModel) + +class PaginatedResults(GenericModel, Generic[T]): + """Paginated results""" + items: list[T] = Field(description = "Items") + page: int = Field(description = "Current Page") + pages: int = Field(description = "Total number of pages") + per_page: int = Field(description = "Number of items per page") + total: int = Field(description = "Total number of items in result") + + +class ItemStorageABC(ABC, Generic[T]): + _on_changed_callbacks: list[Callable[[T], None]] + _on_deleted_callbacks: list[Callable[[str], None]] + + def __init__(self) -> None: + self._on_changed_callbacks = list() + self._on_deleted_callbacks = list() + + """Base item storage class""" + @abstractmethod + def get(self, item_id: str) -> T: + pass + + @abstractmethod + def set(self, item: T) -> None: + pass + + @abstractmethod + def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + pass + + @abstractmethod + def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + pass + + def on_changed(self, on_changed: Callable[[T], None]) -> None: + """Register a callback for when an item is changed""" + self._on_changed_callbacks.append(on_changed) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an item is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_changed(self, item: T) -> None: + for callback in self._on_changed_callbacks: + callback(item) + + def _on_deleted(self, item_id: str) -> None: + for callback in self._on_deleted_callbacks: + callback(item_id) diff --git a/ldm/invoke/app/services/processor.py b/ldm/invoke/app/services/processor.py new file mode 100644 index 0000000000..9b51a6bcbc --- /dev/null +++ b/ldm/invoke/app/services/processor.py @@ -0,0 +1,78 @@ +from threading import Event, Thread +from ..invocations.baseinvocation import InvocationContext +from .invocation_queue import InvocationQueueItem +from .invoker import InvocationProcessorABC, Invoker + + +class DefaultInvocationProcessor(InvocationProcessorABC): + __invoker_thread: Thread + __stop_event: Event + __invoker: Invoker + + def start(self, invoker) -> None: + self.__invoker = invoker + self.__stop_event = Event() + self.__invoker_thread = Thread( + name = "invoker_processor", + target = self.__process, + kwargs = dict(stop_event = self.__stop_event) + ) + self.__invoker_thread.daemon = True # TODO: probably better to just not use threads? + self.__invoker_thread.start() + + + def stop(self, *args, **kwargs) -> None: + self.__stop_event.set() + + + def __process(self, stop_event: Event): + try: + while not stop_event.is_set(): + queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get() + if not queue_item: # Probably stopping + continue + + graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id) + invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) + + # Send starting event + self.__invoker.services.events.emit_invocation_started( + graph_execution_state_id = graph_execution_state.id, + invocation_id = invocation.id + ) + + # Invoke + try: + outputs = invocation.invoke(InvocationContext( + services = self.__invoker.services, + graph_execution_state_id = graph_execution_state.id + )) + + # Save outputs and history + graph_execution_state.complete(invocation.id, outputs) + + # Save the state changes + self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state) + + # Send complete event + self.__invoker.services.events.emit_invocation_complete( + graph_execution_state_id = graph_execution_state.id, + invocation_id = invocation.id, + result = outputs.dict() + ) + + # Queue any further commands if invoking all + is_complete = graph_execution_state.is_complete() + if queue_item.invoke_all and not is_complete: + self.__invoker.invoke(graph_execution_state, invoke_all = True) + elif is_complete: + self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id) + except KeyboardInterrupt: + pass + except Exception as e: + # TODO: Log the error, mark the invocation as failed, and emit an event + print(f'Error invoking {invocation.id}: {e}') + pass + + except KeyboardInterrupt: + ... # Log something? diff --git a/ldm/invoke/app/services/sqlite.py b/ldm/invoke/app/services/sqlite.py new file mode 100644 index 0000000000..8858bbd874 --- /dev/null +++ b/ldm/invoke/app/services/sqlite.py @@ -0,0 +1,119 @@ +import sqlite3 +from threading import Lock +from typing import Generic, TypeVar, Union, get_args +from pydantic import BaseModel, parse_raw_as +from .item_storage import ItemStorageABC, PaginatedResults + +T = TypeVar('T', bound=BaseModel) + +sqlite_memory = ':memory:' + +class SqliteItemStorage(ItemStorageABC, Generic[T]): + _filename: str + _table_name: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _id_field: str + _lock: Lock + + def __init__(self, filename: str, table_name: str, id_field: str = 'id'): + super().__init__() + + self._filename = filename + self._table_name = table_name + self._id_field = id_field # TODO: validate that T has this field + self._lock = Lock() + + self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution + self._cursor = self._conn.cursor() + + self._create_table() + + def _create_table(self): + try: + self._lock.acquire() + self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} ( + item TEXT, + id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''') + self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''') + finally: + self._lock.release() + + def _parse_item(self, item: str) -> T: + item_type = get_args(self.__orig_class__)[0] + return parse_raw_as(item_type, item) + + def set(self, item: T): + try: + self._lock.acquire() + self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),)) + finally: + self._lock.release() + self._on_changed(item) + + def get(self, id: str) -> Union[T, None]: + try: + self._lock.acquire() + self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),)) + result = self._cursor.fetchone() + finally: + self._lock.release() + + if not result: + return None + + return self._parse_item(result[0]) + + def delete(self, id: str): + try: + self._lock.acquire() + self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),)) + finally: + self._lock.release() + self._on_deleted(id) + + def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + try: + self._lock.acquire() + self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page)) + result = self._cursor.fetchall() + + items = list(map(lambda r: self._parse_item(r[0]), result)) + + self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''') + count = self._cursor.fetchone()[0] + finally: + self._lock.release() + + pageCount = int(count / per_page) + 1 + + return PaginatedResults[T]( + items = items, + page = page, + pages = pageCount, + per_page = per_page, + total = count + ) + + def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: + try: + self._lock.acquire() + self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page)) + result = self._cursor.fetchall() + + items = list(map(lambda r: self._parse_item(r[0]), result)) + + self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',)) + count = self._cursor.fetchone()[0] + finally: + self._lock.release() + + pageCount = int(count / per_page) + 1 + + return PaginatedResults[T]( + items = items, + page = page, + pages = pageCount, + per_page = per_page, + total = count + ) diff --git a/pyproject.toml b/pyproject.toml index 6357d25653..3d50bd124d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,9 @@ dependencies = [ "einops", "eventlet", "facexlib", + "fastapi==0.85.0", + "fastapi-events==0.6.0", + "fastapi-socketio==0.0.9", "flask==2.1.3", "flask_cors==3.0.10", "flask_socketio==5.3.0", @@ -66,6 +69,7 @@ dependencies = [ "prompt-toolkit", "pypatchmatch", "pyreadline3", + "python-multipart==0.0.5", "pytorch-lightning==1.7.7", "realesrgan", "requests==2.28.2", @@ -80,6 +84,7 @@ dependencies = [ "torchvision>=0.14.1", "torchmetrics", "transformers~=4.25", + "uvicorn[standard]==0.20.0", "windows-curses; sys_platform=='win32'", ] diff --git a/scripts/invoke-new.py b/scripts/invoke-new.py new file mode 100644 index 0000000000..2bc5330a5c --- /dev/null +++ b/scripts/invoke-new.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import os +import sys + +def main(): + # Change working directory to the repo root + os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + if '--web' in sys.argv: + from ldm.invoke.app.api_app import invoke_api + invoke_api() + else: + # TODO: Parse some top-level args here. + from ldm.invoke.app.cli_app import invoke_cli + invoke_cli() + + +if __name__ == '__main__': + main() diff --git a/static/dream_web/test.html b/static/dream_web/test.html new file mode 100644 index 0000000000..e99abb3703 --- /dev/null +++ b/static/dream_web/test.html @@ -0,0 +1,206 @@ + + + + InvokeAI Test + + + + + + + + + + + + + + + +
+ +
+ + + + + \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/nodes/__init__.py b/tests/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py new file mode 100644 index 0000000000..0a5dcc7734 --- /dev/null +++ b/tests/nodes/test_graph_execution_state.py @@ -0,0 +1,114 @@ +from .test_invoker import create_edge +from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation +from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.services.invocation_services import InvocationServices +from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation +import pytest + + +@pytest.fixture +def simple_graph(): + g = Graph() + g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) + g.add_node(ImageTestInvocation(id = "2")) + g.add_edge(create_edge("1", "prompt", "2", "prompt")) + return g + +@pytest.fixture +def mock_services(): + # NOTE: none of these are actually called by the test invocations + return InvocationServices(generate = None, events = None, images = None) + +def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: + n = g.next() + if n is None: + return (None, None) + + print(f'invoking {n.id}: {type(n)}') + o = n.invoke(InvocationContext(services, "1")) + g.complete(n.id, o) + + return (n, o) + +def test_graph_state_executes_in_order(simple_graph, mock_services): + g = GraphExecutionState(graph = simple_graph) + + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = g.next() + + assert g.prepared_source_mapping[n1[0].id] == "1" + assert g.prepared_source_mapping[n2[0].id] == "2" + assert n3 is None + assert g.results[n1[0].id].prompt == n1[0].prompt + assert n2[0].prompt == n1[0].prompt + +def test_graph_is_complete(simple_graph, mock_services): + g = GraphExecutionState(graph = simple_graph) + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = g.next() + + assert g.is_complete() + +def test_graph_is_not_complete(simple_graph, mock_services): + g = GraphExecutionState(graph = simple_graph) + n1 = invoke_next(g, mock_services) + n2 = g.next() + + assert not g.is_complete() + +# TODO: test completion with iterators/subgraphs + +def test_graph_state_expands_iterator(mock_services): + graph = Graph() + test_prompts = ["Banana sushi", "Cat sushi"] + graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) + graph.add_node(IterateInvocation(id = "2")) + graph.add_node(ImageTestInvocation(id = "3")) + graph.add_edge(create_edge("1", "collection", "2", "collection")) + graph.add_edge(create_edge("2", "item", "3", "prompt")) + + g = GraphExecutionState(graph = graph) + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = invoke_next(g, mock_services) + n4 = invoke_next(g, mock_services) + n5 = invoke_next(g, mock_services) + + assert g.prepared_source_mapping[n1[0].id] == "1" + assert g.prepared_source_mapping[n2[0].id] == "2" + assert g.prepared_source_mapping[n3[0].id] == "2" + assert g.prepared_source_mapping[n4[0].id] == "3" + assert g.prepared_source_mapping[n5[0].id] == "3" + + assert isinstance(n4[0], ImageTestInvocation) + assert isinstance(n5[0], ImageTestInvocation) + + prompts = [n4[0].prompt, n5[0].prompt] + assert sorted(prompts) == sorted(test_prompts) + +def test_graph_state_collects(mock_services): + graph = Graph() + test_prompts = ["Banana sushi", "Cat sushi"] + graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) + graph.add_node(IterateInvocation(id = "2")) + graph.add_node(PromptTestInvocation(id = "3")) + graph.add_node(CollectInvocation(id = "4")) + graph.add_edge(create_edge("1", "collection", "2", "collection")) + graph.add_edge(create_edge("2", "item", "3", "prompt")) + graph.add_edge(create_edge("3", "prompt", "4", "item")) + + g = GraphExecutionState(graph = graph) + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = invoke_next(g, mock_services) + n4 = invoke_next(g, mock_services) + n5 = invoke_next(g, mock_services) + n6 = invoke_next(g, mock_services) + + assert isinstance(n6[0], CollectInvocation) + + assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py new file mode 100644 index 0000000000..a6d96f61c0 --- /dev/null +++ b/tests/nodes/test_invoker.py @@ -0,0 +1,85 @@ +from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until +from ldm.invoke.app.services.processor import DefaultInvocationProcessor +from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory +from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue +from ldm.invoke.app.services.invoker import Invoker, InvokerServices +from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.services.invocation_services import InvocationServices +from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation +import pytest + + +@pytest.fixture +def simple_graph(): + g = Graph() + g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) + g.add_node(ImageTestInvocation(id = "2")) + g.add_edge(create_edge("1", "prompt", "2", "prompt")) + return g + +@pytest.fixture +def mock_services() -> InvocationServices: + # NOTE: none of these are actually called by the test invocations + return InvocationServices(generate = None, events = TestEventService(), images = None) + +@pytest.fixture() +def mock_invoker_services() -> InvokerServices: + return InvokerServices( + queue = MemoryInvocationQueue(), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor() + ) + +@pytest.fixture() +def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker: + return Invoker( + services = mock_services, + invoker_services = mock_invoker_services + ) + +def test_can_create_graph_state(mock_invoker: Invoker): + g = mock_invoker.create_execution_state() + mock_invoker.stop() + + assert g is not None + assert isinstance(g, GraphExecutionState) + +def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): + g = mock_invoker.create_execution_state(graph = simple_graph) + mock_invoker.stop() + + assert g is not None + assert isinstance(g, GraphExecutionState) + assert g.graph == simple_graph + +def test_can_invoke(mock_invoker: Invoker, simple_graph): + g = mock_invoker.create_execution_state(graph = simple_graph) + invocation_id = mock_invoker.invoke(g) + assert invocation_id is not None + + def has_executed_any(g: GraphExecutionState): + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + return len(g.executed) > 0 + + wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) + mock_invoker.stop() + + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + assert len(g.executed) > 0 + +def test_can_invoke_all(mock_invoker: Invoker, simple_graph): + g = mock_invoker.create_execution_state(graph = simple_graph) + invocation_id = mock_invoker.invoke(g, invoke_all = True) + assert invocation_id is not None + + def has_executed_all(g: GraphExecutionState): + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + return g.is_complete() + + wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) + mock_invoker.stop() + + g = mock_invoker.invoker_services.graph_execution_manager.get(g.id) + assert g.is_complete() diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py new file mode 100644 index 0000000000..1b5b341192 --- /dev/null +++ b/tests/nodes/test_node_graph.py @@ -0,0 +1,501 @@ +from ldm.invoke.app.invocations.image import * + +from .test_nodes import ListPassThroughInvocation, PromptTestInvocation +from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation +from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation +from ldm.invoke.app.invocations.upscale import UpscaleInvocation +import pytest + + +# Helpers +def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: + return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) + +# Tests +def test_connections_are_compatible(): + from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_field = "image" + to_node = UpscaleInvocation(id = "2") + to_field = "image" + + result = are_connections_compatible(from_node, from_field, to_node, to_field) + + assert result == True + +def test_connections_are_incompatible(): + from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_field = "image" + to_node = UpscaleInvocation(id = "2") + to_field = "strength" + + result = are_connections_compatible(from_node, from_field, to_node, to_field) + + assert result == False + +def test_connections_incompatible_with_invalid_fields(): + from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_field = "invalid_field" + to_node = UpscaleInvocation(id = "2") + to_field = "image" + + # From field is invalid + result = are_connections_compatible(from_node, from_field, to_node, to_field) + assert result == False + + # To field is invalid + from_field = "image" + to_field = "invalid_field" + + result = are_connections_compatible(from_node, from_field, to_node, to_field) + assert result == False + +def test_graph_can_add_node(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + + assert n.id in g.nodes + +def test_graph_fails_to_add_node_with_duplicate_id(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second") + + with pytest.raises(NodeAlreadyInGraphError): + g.add_node(n2) + +def test_graph_updates_node(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + g.add_node(n2) + + nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated") + + g.update_node("1", nu) + + assert g.nodes["1"].prompt == "Banana sushi updated" + +def test_graph_fails_to_update_node_if_type_changes(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = UpscaleInvocation(id = "2") + g.add_node(n2) + + nu = UpscaleInvocation(id = "1") + + with pytest.raises(TypeError): + g.update_node("1", nu) + +def test_graph_allows_non_conflicting_id_change(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = UpscaleInvocation(id = "2") + g.add_node(n2) + e1 = create_edge(n.id,"image",n2.id,"image") + g.add_edge(e1) + + nu = TextToImageInvocation(id = "3", prompt = "Banana sushi") + g.update_node("1", nu) + + with pytest.raises(NodeNotFoundError): + g.get_node("1") + + assert g.get_node("3").prompt == "Banana sushi" + + assert len(g.edges) == 1 + assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges + +def test_graph_fails_to_update_node_id_if_conflict(): + g = Graph() + n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.add_node(n) + n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + g.add_node(n2) + + nu = TextToImageInvocation(id = "2", prompt = "Banana sushi") + with pytest.raises(NodeAlreadyInGraphError): + g.update_node("1", nu) + +def test_graph_adds_edge(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + + g.add_edge(e) + + assert e in g.edges + +def test_graph_fails_to_add_edge_with_cycle(): + g = Graph() + n1 = UpscaleInvocation(id = "1") + g.add_node(n1) + e = create_edge(n1.id,"image",n1.id,"image") + with pytest.raises(InvalidEdgeError): + g.add_edge(e) + +def test_graph_fails_to_add_edge_with_long_cycle(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + n3 = UpscaleInvocation(id = "3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + e1 = create_edge(n1.id,"image",n2.id,"image") + e2 = create_edge(n2.id,"image",n3.id,"image") + e3 = create_edge(n3.id,"image",n2.id,"image") + g.add_edge(e1) + g.add_edge(e2) + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + +def test_graph_fails_to_add_edge_with_missing_node_id(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e1 = create_edge("1","image","3","image") + e2 = create_edge("3","image","1","image") + with pytest.raises(InvalidEdgeError): + g.add_edge(e1) + with pytest.raises(InvalidEdgeError): + g.add_edge(e2) + +def test_graph_fails_to_add_edge_when_destination_exists(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + n3 = UpscaleInvocation(id = "3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + e1 = create_edge(n1.id,"image",n2.id,"image") + e2 = create_edge(n1.id,"image",n3.id,"image") + e3 = create_edge(n2.id,"image",n3.id,"image") + g.add_edge(e1) + g.add_edge(e2) + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + + +def test_graph_fails_to_add_edge_with_mismatched_types(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e1 = create_edge("1","image","2","strength") + with pytest.raises(InvalidEdgeError): + g.add_edge(e1) + +def test_graph_connects_collector(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2") + n3 = CollectInvocation(id = "3") + n4 = ListPassThroughInvocation(id = "4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + e1 = create_edge("1","image","3","item") + e2 = create_edge("2","image","3","item") + e3 = create_edge("3","collection","4","collection") + g.add_edge(e1) + g.add_edge(e2) + g.add_edge(e3) + +# TODO: test that derived types mixed with base types are compatible + +def test_graph_collector_invalid_with_varying_input_types(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") + n3 = CollectInvocation(id = "3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + e1 = create_edge("1","image","3","item") + e2 = create_edge("2","prompt","3","item") + g.add_edge(e1) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e2) + +def test_graph_collector_invalid_with_varying_input_output(): + g = Graph() + n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") + n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") + n3 = CollectInvocation(id = "3") + n4 = ListPassThroughInvocation(id = "4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + e1 = create_edge("1","prompt","3","item") + e2 = create_edge("2","prompt","3","item") + e3 = create_edge("3","collection","4","collection") + g.add_edge(e1) + g.add_edge(e2) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + +def test_graph_collector_invalid_with_non_list_output(): + g = Graph() + n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") + n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") + n3 = CollectInvocation(id = "3") + n4 = PromptTestInvocation(id = "4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + e1 = create_edge("1","prompt","3","item") + e2 = create_edge("2","prompt","3","item") + e3 = create_edge("3","collection","4","prompt") + g.add_edge(e1) + g.add_edge(e2) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + +def test_graph_connects_iterator(): + g = Graph() + n1 = ListPassThroughInvocation(id = "1") + n2 = IterateInvocation(id = "2") + n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + e1 = create_edge("1","collection","2","collection") + e2 = create_edge("2","item","3","image") + g.add_edge(e1) + g.add_edge(e2) + +# TODO: TEST INVALID ITERATOR SCENARIOS + +def test_graph_iterator_invalid_if_multiple_inputs(): + g = Graph() + n1 = ListPassThroughInvocation(id = "1") + n2 = IterateInvocation(id = "2") + n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") + n4 = ListPassThroughInvocation(id = "4") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + + e1 = create_edge("1","collection","2","collection") + e2 = create_edge("2","item","3","image") + e3 = create_edge("4","collection","2","collection") + g.add_edge(e1) + g.add_edge(e2) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e3) + +def test_graph_iterator_invalid_if_input_not_list(): + g = Graph() + n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi") + n2 = IterateInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + + e1 = create_edge("1","collection","2","collection") + + with pytest.raises(InvalidEdgeError): + g.add_edge(e1) + +def test_graph_iterator_invalid_if_output_and_input_types_different(): + g = Graph() + n1 = ListPassThroughInvocation(id = "1") + n2 = IterateInvocation(id = "2") + n3 = PromptTestInvocation(id = "3", prompt = "Banana sushi") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + + e1 = create_edge("1","collection","2","collection") + e2 = create_edge("2","item","3","prompt") + g.add_edge(e1) + + with pytest.raises(InvalidEdgeError): + g.add_edge(e2) + +def test_graph_validates(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e1 = create_edge("1","image","2","image") + g.add_edge(e1) + + assert g.is_valid() == True + +def test_graph_invalid_if_edges_reference_missing_nodes(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + g.nodes[n1.id] = n1 + e1 = create_edge("1","image","2","image") + g.edges.append(e1) + + assert g.is_valid() == False + +def test_graph_invalid_if_subgraph_invalid(): + g = Graph() + n1 = GraphInvocation(id = "1") + n1.graph = Graph() + + n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi") + n1.graph.nodes[n1_1.id] = n1_1 + e1 = create_edge("1","image","2","image") + n1.graph.edges.append(e1) + + g.nodes[n1.id] = n1 + + assert g.is_valid() == False + +def test_graph_invalid_if_has_cycle(): + g = Graph() + n1 = UpscaleInvocation(id = "1") + n2 = UpscaleInvocation(id = "2") + g.nodes[n1.id] = n1 + g.nodes[n2.id] = n2 + e1 = create_edge("1","image","2","image") + e2 = create_edge("2","image","1","image") + g.edges.append(e1) + g.edges.append(e2) + + assert g.is_valid() == False + +def test_graph_invalid_with_invalid_connection(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.nodes[n1.id] = n1 + g.nodes[n2.id] = n2 + e1 = create_edge("1","image","2","strength") + g.edges.append(e1) + + assert g.is_valid() == False + + +# TODO: Subgraph operations +def test_graph_gets_subgraph_node(): + g = Graph() + n1 = GraphInvocation(id = "1") + n1.graph = Graph() + n1.graph.add_node + + n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1.graph.add_node(n1_1) + + g.add_node(n1) + + result = g.get_node('1.1') + + assert result is not None + assert result.id == '1' + assert result == n1_1 + +def test_graph_fails_to_get_missing_subgraph_node(): + g = Graph() + n1 = GraphInvocation(id = "1") + n1.graph = Graph() + n1.graph.add_node + + n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1.graph.add_node(n1_1) + + g.add_node(n1) + + with pytest.raises(NodeNotFoundError): + result = g.get_node('1.2') + +def test_graph_fails_to_enumerate_non_subgraph_node(): + g = Graph() + n1 = GraphInvocation(id = "1") + n1.graph = Graph() + n1.graph.add_node + + n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1.graph.add_node(n1_1) + + g.add_node(n1) + + n2 = UpscaleInvocation(id = "2") + g.add_node(n2) + + with pytest.raises(NodeNotFoundError): + result = g.get_node('2.1') + +def test_graph_gets_networkx_graph(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + g.add_edge(e) + + nxg = g.nx_graph() + + assert '1' in nxg.nodes + assert '2' in nxg.nodes + assert ('1','2') in nxg.edges + + +# TODO: Graph serializes and deserializes +def test_graph_can_serialize(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + g.add_edge(e) + + # Not throwing on this line is sufficient + json = g.json() + +def test_graph_can_deserialize(): + g = Graph() + n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n2 = UpscaleInvocation(id = "2") + g.add_node(n1) + g.add_node(n2) + e = create_edge(n1.id,"image",n2.id,"image") + g.add_edge(e) + + json = g.json() + g2 = Graph.parse_raw(json) + + assert g2 is not None + assert g2.nodes['1'] is not None + assert g2.nodes['2'] is not None + assert len(g2.edges) == 1 + assert g2.edges[0][0].node_id == '1' + assert g2.edges[0][0].field == 'image' + assert g2.edges[0][1].node_id == '2' + assert g2.edges[0][1].field == 'image' + +def test_graph_can_generate_schema(): + # Not throwing on this line is sufficient + # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation + schema = Graph.schema_json(indent = 2) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py new file mode 100644 index 0000000000..fea2e75e95 --- /dev/null +++ b/tests/nodes/test_nodes.py @@ -0,0 +1,92 @@ +from typing import Any, Callable, Literal +from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from ldm.invoke.app.invocations.image import ImageField +from ldm.invoke.app.services.invocation_services import InvocationServices +from pydantic import Field +import pytest + +# Define test invocations before importing anything that uses invocations +class ListPassThroughInvocationOutput(BaseInvocationOutput): + type: Literal['test_list_output'] = 'test_list_output' + + collection: list[ImageField] = Field(default_factory=list) + +class ListPassThroughInvocation(BaseInvocation): + type: Literal['test_list'] = 'test_list' + + collection: list[ImageField] = Field(default_factory=list) + + def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: + return ListPassThroughInvocationOutput(collection = self.collection) + +class PromptTestInvocationOutput(BaseInvocationOutput): + type: Literal['test_prompt_output'] = 'test_prompt_output' + + prompt: str = Field(default = "") + +class PromptTestInvocation(BaseInvocation): + type: Literal['test_prompt'] = 'test_prompt' + + prompt: str = Field(default = "") + + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + return PromptTestInvocationOutput(prompt = self.prompt) + +class ImageTestInvocationOutput(BaseInvocationOutput): + type: Literal['test_image_output'] = 'test_image_output' + + image: ImageField = Field() + +class ImageTestInvocation(BaseInvocation): + type: Literal['test_image'] = 'test_image' + + prompt: str = Field(default = "") + + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) + +class PromptCollectionTestInvocationOutput(BaseInvocationOutput): + type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' + collection: list[str] = Field(default_factory=list) + +class PromptCollectionTestInvocation(BaseInvocation): + type: Literal['test_prompt_collection'] = 'test_prompt_collection' + collection: list[str] = Field() + + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) + + +from ldm.invoke.app.services.events import EventServiceBase +from ldm.invoke.app.services.graph import EdgeConnection + +def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: + return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) + + +class TestEvent: + event_name: str + payload: Any + + def __init__(self, event_name: str, payload: Any): + self.event_name = event_name + self.payload = payload + +class TestEventService(EventServiceBase): + events: list + + def __init__(self): + super().__init__() + self.events = list() + + def dispatch(self, event_name: str, payload: Any) -> None: + pass + +def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None: + import time + start_time = time.time() + while time.time() - start_time < timeout: + if condition(): + return + time.sleep(interval) + raise TimeoutError("Condition not met") \ No newline at end of file diff --git a/tests/nodes/test_sqlite.py b/tests/nodes/test_sqlite.py new file mode 100644 index 0000000000..e499bbce12 --- /dev/null +++ b/tests/nodes/test_sqlite.py @@ -0,0 +1,112 @@ +from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory +from pydantic import BaseModel, Field + + +class TestModel(BaseModel): + id: str = Field(description = "ID") + name: str = Field(description = "Name") + + +def test_sqlite_service_can_create_and_get(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + assert db.get('1') == TestModel(id = '1', name = 'Test') + +def test_sqlite_service_can_list(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.list() + assert results.page == 0 + assert results.pages == 1 + assert results.per_page == 10 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] + +def test_sqlite_service_can_delete(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.delete('1') + assert db.get('1') is None + +def test_sqlite_service_calls_set_callback(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + called = False + def on_changed(item: TestModel): + nonlocal called + called = True + db.on_changed(on_changed) + db.set(TestModel(id = '1', name = 'Test')) + assert called + +def test_sqlite_service_calls_delete_callback(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + called = False + def on_deleted(item_id: str): + nonlocal called + called = True + db.on_deleted(on_deleted) + db.set(TestModel(id = '1', name = 'Test')) + db.delete('1') + assert called + +def test_sqlite_service_can_list_with_pagination(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.list(page = 0, per_page = 2) + assert results.page == 0 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] + +def test_sqlite_service_can_list_with_pagination_and_offset(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.list(page = 1, per_page = 2) + assert results.page == 1 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '3', name = 'Test')] + +def test_sqlite_service_can_search(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.search(query = 'Test') + assert results.page == 0 + assert results.pages == 1 + assert results.per_page == 10 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] + +def test_sqlite_service_can_search_with_pagination(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.search(query = 'Test', page = 0, per_page = 2) + assert results.page == 0 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] + +def test_sqlite_service_can_search_with_pagination_and_offset(): + db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id = '2', name = 'Test')) + db.set(TestModel(id = '3', name = 'Test')) + results = db.search(query = 'Test', page = 1, per_page = 2) + assert results.page == 1 + assert results.pages == 2 + assert results.per_page == 2 + assert results.total == 3 + assert results.items == [TestModel(id = '3', name = 'Test')]