author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800
committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800

Adding base node architecture

Fix type annotation errors

Runs and generates, but breaks in saving session

Fix default model value setting. Fix deprecation warning.

Fixed node api

Adding markdown docs

Simplifying Generate construction in apps

[nodes] A few minor changes (#2510)

* Pin api-related requirements

* Remove confusing extra CORS origins list

* Adds response models for HTTP 200

[nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest.

Minor typing fixes

[nodes] Fix some small output query hookups

[node] Fixing some additional typing issues

[nodes] Move and expand graph code. Add base item storage and sqlite implementation.

Update startup to match new code

[nodes] Add callbacks to item storage

[nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility

[nodes] New execution model that handles iteration

[nodes] Fixing the CLI

[nodes] Adding a note to the CLI

[nodes] Split processing thread into separate service

[node] Add error message on node processing failure

Removing old files and duplicated packages

Adding python-multipart
This commit is contained in:
Kyle Schouviller 2022-11-30 21:33:20 -08:00
parent 49ffb64ef3
commit 34e3aa1f88
42 changed files with 4510 additions and 0 deletions

6
.coveragerc Normal file
View File

@ -0,0 +1,6 @@
[run]
omit='.env/*'
source='.'
[report]
show_missing = true

1
.gitignore vendored
View File

@ -68,6 +68,7 @@ htmlcov/
.cache
nosetests.xml
coverage.xml
cov.xml
*.cover
*.py,cover
.hypothesis/

5
.pytest.ini Normal file
View File

@ -0,0 +1,5 @@
[pytest]
DJANGO_SETTINGS_MODULE = webtas.settings
; python_files = tests.py test_*.py *_tests.py
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml

View File

@ -0,0 +1,93 @@
# Invoke.AI Architecture
```mermaid
flowchart TB
subgraph apps[Applications]
webui[WebUI]
cli[CLI]
subgraph webapi[Web API]
api[HTTP API]
sio[Socket.IO]
end
end
subgraph invoke[Invoke]
direction LR
invoker
services
sessions
invocations
end
subgraph core[AI Core]
Generate
end
webui --> webapi
webapi --> invoke
cli --> invoke
invoker --> services & sessions
invocations --> services
sessions --> invocations
services --> core
%% Styles
classDef sg fill:#5028C8,font-weight:bold,stroke-width:2,color:#fff,stroke:#14141A
classDef default stroke-width:2px,stroke:#F6B314,color:#fff,fill:#14141A
class apps,webapi,invoke,core sg
```
## Applications
Applications are built on top of the invoke framework. They should construct `invoker` and then interact through it. They should avoid interacting directly with core code in order to support a variety of configurations.
### Web UI
The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/frontend` and the backend code is found in `/ldm/invoke/app/api_app.py` and `/ldm/invoke/app/api/`. The code is further organized as such:
| Component | Description |
| --- | --- |
| api_app.py | Sets up the API app, annotates the OpenAPI spec with additional data, and runs the API |
| dependencies | Creates all invoker services and the invoker, and provides them to the API |
| events | An eventing system that could in the future be adapted to support horizontal scale-out |
| sockets | The Socket.IO interface - handles listening to and emitting session events (events are defined in the events service module) |
| routers | API definitions for different areas of API functionality |
### CLI
The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/ldm/invoke/app/cli_app.py`.
## Invoke
The Invoke framework provides the interface to the underlying AI systems and is built with flexibility and extensibility in mind. There are four major concepts: invoker, sessions, invocations, and services.
### Invoker
The invoker (`/ldm/invoke/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
- **invocation services**, which are used by invocations to interact with core functionality.
- **invoker services**, which are used by the invoker to manage sessions and manage the invocation queue.
### Sessions
Invocations and links between them form a graph, which is maintained in a session. Sessions can be queued for invocation, which will execute their graph (either the next ready invocation, or all invocations). Sessions also maintain execution history for the graph (including storage of any outputs). An invocation may be added to a session at any time, and there is capability to add and entire graph at once, as well as to automatically link new invocations to previous invocations. Invocations can not be deleted or modified once added.
The session graph does not support looping. This is left as an application problem to prevent additional complexity in the graph.
### Invocations
Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/ldm/invoke/app/invocations`, and are all automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.
### Services
Services provide invocations access AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/ldm/invoke/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import).
## AI Core
The AI Core is represented by the rest of the code base (i.e. the code outside of `/ldm/invoke/app/`).

View File

@ -0,0 +1,105 @@
# Invocations
Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
## Creating a new invocation
To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
An invocation looks like this:
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description = "The upscale level")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
```
Each portion is important to implement correctly.
### Class definition and type
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
```
All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
### Inputs
```py
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description="The upscale level")
```
Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
| Part | Value | Description |
| ---- | ----- | ----------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
```
The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
### Outputs
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal['image'] = 'image'
image: ImageField = Field(default=None, description="The output image")
```
Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.

View File

@ -1030,6 +1030,8 @@ class Generate:
image_callback=None,
prefix=None,
):
results = []
for r in image_list:
image, seed = r
try:
@ -1083,6 +1085,10 @@ class Generate:
else:
r[0] = image
results.append([image, seed])
return results
def apply_textmask(
self, image_path: str, prompt: str, callback, threshold: float = 0.5
):

View File

@ -0,0 +1,83 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from argparse import Namespace
import os
from ..services.processor import DefaultInvocationProcessor
from ..services.graph import GraphExecutionState
from ..services.sqlite import SqliteItemStorage
from ...globals import Globals
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker, InvokerServices
from ..services.generate_initializer import get_generate
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?
def check_internet()->bool:
'''
Return true if the internet is reachable.
It does this by pinging huggingface.co.
'''
import urllib.request
host = 'http://huggingface.co'
try:
urllib.request.urlopen(host,timeout=1)
return True
except:
return False
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker = None
@staticmethod
def initialize(
args,
config,
event_handler_id: int
):
Globals.try_patchmatch = args.patchmatch
Globals.always_use_cpu = args.always_use_cpu
Globals.internet_available = args.internet_available and check_internet()
Globals.disable_xformers = not args.xformers
Globals.ckpt_convert = args.ckpt_convert
# TODO: Use a logger
print(f'>> Internet connectivity is {Globals.internet_available}')
generate = get_generate(args, config)
events = FastAPIEventService(event_handler_id)
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../outputs'))
images = DiskImageStorage(output_folder)
services = InvocationServices(
generate = generate,
events = events,
images = images
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db')
invoker_services = InvokerServices(
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
ApiDependencies.invoker = Invoker(services, invoker_services)
@staticmethod
def shutdown():
if ApiDependencies.invoker:
ApiDependencies.invoker.stop()

View File

@ -0,0 +1,54 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events import EventServiceBase
import threading
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put(dict(
event_name = event_name,
payload = payload
))
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block = False)
if not event: # Probably stopping
continue
dispatch(
event.get('event_name'),
payload = event.get('payload'),
middleware_id = self.event_handler_id)
except Empty:
await asyncio.sleep(0.001)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -0,0 +1,57 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from fastapi import Path, UploadFile, Request
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse, Response
from PIL import Image
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(
prefix = '/v1/images',
tags = ['images']
)
@images_router.get('/{image_type}/{image_name}',
operation_id = 'get_image'
)
async def get_image(
image_type: ImageType = Path(description = "The type of image to get"),
image_name: str = Path(description = "The name of the image to get")
):
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.post('/uploads/',
operation_id = 'upload_image',
responses = {
201: {'description': 'The image was uploaded successfully'},
404: {'description': 'Session not found'}
})
async def upload_image(
file: UploadFile,
request: Request
):
if not file.content_type.startswith('image'):
return Response(status_code = 415)
contents = await file.read()
try:
im = Image.open(contents)
except:
# Error opening the image
return Response(status_code = 415)
filename = f'{str(int(datetime.now(timezone.utc).timestamp()))}.png'
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
return Response(
status_code=201,
headers = {
'Location': request.url_for('get_image', image_type=ImageType.UPLOAD, image_name=filename)
}
)

View File

@ -0,0 +1,232 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import List, Optional, Union, Annotated
from fastapi import Query, Path, Body
from fastapi.routing import APIRouter
from fastapi.responses import Response
from pydantic.fields import Field
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
from ...invocations import *
session_router = APIRouter(
prefix = '/v1/sessions',
tags = ['sessions']
)
@session_router.post('/',
operation_id = 'create_session',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid json'}
})
async def create_session(
graph: Optional[Graph] = Body(default = None, description = "The graph to initialize the session with")
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
return session
@session_router.get('/',
operation_id = 'list_sessions',
responses = {
200: {"model": PaginatedResults[GraphExecutionState]}
})
async def list_sessions(
page: int = Query(default = 0, description = "The page of results to get"),
per_page: int = Query(default = 10, description = "The number of results per page"),
query: str = Query(default = '', description = "The query string to search for")
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if filter == '':
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.list(page, per_page)
else:
result = ApiDependencies.invoker.invoker_services.graph_execution_manager.search(query, page, per_page)
return result
@session_router.get('/{session_id}',
operation_id = 'get_session',
responses = {
200: {"model": GraphExecutionState},
404: {'description': 'Session not found'}
})
async def get_session(
session_id: str = Path(description = "The id of the session to get")
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
else:
return session
@session_router.post('/{session_id}/nodes',
operation_id = 'add_node',
responses = {
200: {"model": str},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_node(
session_id: str = Path(description = "The id of the session"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The node to add")
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_node(node)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session.id
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/nodes/{node_path}',
operation_id = 'update_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def update_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node in the graph"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body(description = "The new node")
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.update_node(node_path, node)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.delete('/{session_id}/nodes/{node_path}',
operation_id = 'delete_node',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_node(
session_id: str = Path(description = "The id of the session"),
node_path: str = Path(description = "The path to the node to delete")
) -> GraphExecutionState:
"""Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.delete_node(node_path)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.post('/{session_id}/edges',
operation_id = 'add_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def add_edge(
session_id: str = Path(description = "The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description = "The edge to add")
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
session.add_edge(edge)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
# TODO: the edge being in the path here is really ugly, find a better solution
@session_router.delete('/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
operation_id = 'delete_edge',
responses = {
200: {"model": GraphExecutionState},
400: {'description': 'Invalid node or link'},
404: {'description': 'Session not found'}
}
)
async def delete_edge(
session_id: str = Path(description = "The id of the session"),
from_node_id: str = Path(description = "The id of the node the edge is coming from"),
from_field: str = Path(description = "The field of the node the edge is coming from"),
to_node_id: str = Path(description = "The id of the node the edge is going to"),
to_field: str = Path(description = "The field of the node the edge is going to")
) -> GraphExecutionState:
"""Deletes an edge from the graph"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
try:
edge = (EdgeConnection(node_id = from_node_id, field = from_field), EdgeConnection(node_id = to_node_id, field = to_field))
session.delete_edge(edge)
ApiDependencies.invoker.invoker_services.graph_execution_manager.set(session) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
return Response(status_code = 400)
except IndexError:
return Response(status_code = 400)
@session_router.put('/{session_id}/invoke',
operation_id = 'invoke_session',
responses = {
200: {"model": None},
202: {'description': 'The invocation is queued'},
400: {'description': 'The session has no invocations ready to invoke'},
404: {'description': 'Session not found'}
})
async def invoke_session(
session_id: str = Path(description = "The id of the session to invoke"),
all: bool = Query(default = False, description = "Whether or not to invoke all remaining invocations")
) -> None:
"""Invokes a session"""
session = ApiDependencies.invoker.invoker_services.graph_execution_manager.get(session_id)
if session is None:
return Response(status_code = 404)
if session.is_complete():
return Response(status_code = 400)
ApiDependencies.invoker.invoke(session, invoke_all = all)
return Response(status_code=202)

View File

@ -0,0 +1,36 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import FastAPI
from fastapi_socketio import SocketManager
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from ..services.events import EventServiceBase
class SocketIO:
__sio: SocketManager
def __init__(self, app: FastAPI):
self.__sio = SocketManager(app = app)
self.__sio.on('subscribe', handler=self._handle_sub)
self.__sio.on('unsubscribe', handler=self._handle_unsub)
local_handler.register(
event_name = EventServiceBase.session_event,
_func=self._handle_session_event
)
async def _handle_session_event(self, event: Event):
await self.__sio.emit(
event = event[1]['event'],
data = event[1]['data'],
room = event[1]['data']['graph_execution_state_id']
)
async def _handle_sub(self, sid, data, *args, **kwargs):
if 'session' in data:
self.__sio.enter_room(sid, data['session'])
# @app.sio.on('unsubscribe')
async def _handle_unsub(self, sid, data, *args, **kwargs):
if 'session' in data:
self.__sio.leave_room(sid, data['session'])

164
ldm/invoke/app/api_app.py Normal file
View File

@ -0,0 +1,164 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from inspect import signature
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from fastapi.staticfiles import StaticFiles
from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.handlers.local import local_handler
from fastapi.middleware.cors import CORSMiddleware
from pydantic.schema import schema
import uvicorn
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .api.routers import images, sessions
from .api.dependencies import ApiDependencies
from ..args import Args
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
title = "Invoke AI",
docs_url = None,
redoc_url = None
)
# Add event handler
event_handler_id: int = id(app)
app.add_middleware(
EventHandlerASGIMiddleware,
handlers = [local_handler], # TODO: consider doing this in services to support different configurations
middleware_id = event_handler_id)
# Add CORS
# TODO: use configuration for this
origins = []
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
socket_io = SocketIO(app)
config = {}
# Add startup event to load dependencies
@app.on_event('startup')
async def startup_event():
args = Args()
config = args.parse_args()
ApiDependencies.initialize(
args = args,
config = config,
event_handler_id = event_handler_id
)
# Shut down threads
@app.on_event('shutdown')
async def shutdown_event():
ApiDependencies.shutdown()
# Include all routers
# TODO: REMOVE
# app.include_router(
# invocation.invocation_router,
# prefix = '/api')
app.include_router(
sessions.session_router,
prefix = '/api'
)
app.include_router(
images.images_router,
prefix = '/api'
)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title = app.title,
description = "An API for invoking AI image operations",
version = "1.0.0",
routes = app.routes
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = dict()
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
for schema_key, output_schema in output_schemas['definitions'].items():
openapi_schema["components"]["schemas"][schema_key] = output_schema
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema['title']
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__
output_type = signature(invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' }
if 'additionalProperties' not in invoker_schema:
invoker_schema['additionalProperties'] = {}
invoker_schema['additionalProperties']['outputs'] = outputs_ref
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
# Override API doc favicons
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
openapi_url=app.openapi_url,
title=app.title,
swagger_favicon_url="/static/favicon.ico"
)
@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
return get_redoc_html(
openapi_url=app.openapi_url,
title=app.title,
redoc_favicon_url="/static/favicon.ico"
)
def invoke_api():
# Start our own event loop for eventing usage
# TODO: determine if there's a better way to do this
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app = app,
host = "0.0.0.0",
port = 9090,
loop = loop)
# Use access_log to turn off logging
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

306
ldm/invoke/app/cli_app.py Normal file
View File

@ -0,0 +1,306 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse
import shlex
import os
import time
from typing import Any, Dict, Iterable, Literal, Union, get_args, get_origin, get_type_hints
from pydantic import BaseModel
from pydantic.fields import Field
from .services.processor import DefaultInvocationProcessor
from .services.graph import EdgeConnection, GraphExecutionState
from .services.sqlite import SqliteItemStorage
from .invocations.image import ImageField
from .services.generate_initializer import get_generate
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .invocations.baseinvocation import BaseInvocation
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker, InvokerServices
from .invocations import *
from ..args import Args
from .services.events import EventServiceBase
class InvocationCommand(BaseModel):
invocation: Union[BaseInvocation.get_invocations()] = Field(discriminator="type")
class InvalidArgs(Exception):
pass
def get_invocation_parser() -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
def exit(*args, **kwargs):
raise InvalidArgs
parser.exit = exit
subparsers = parser.add_subparsers(dest='type')
invocation_parsers = dict()
# Add history parser
history_parser = subparsers.add_parser('history', help="Shows the invocation history")
history_parser.add_argument('count', nargs='?', default=5, type=int, help="The number of history entries to show")
# Add default parser
default_parser = subparsers.add_parser('default', help="Define a default value for all inputs with a specified name")
default_parser.add_argument('input', type=str, help="The input field")
default_parser.add_argument('value', help="The default value")
default_parser = subparsers.add_parser('reset_default', help="Resets a default value")
default_parser.add_argument('input', type=str, help="The input field")
# Create subparsers for each invocation
invocations = BaseInvocation.get_all_subclasses()
for invocation in invocations:
hints = get_type_hints(invocation)
cmd_name = get_args(hints['type'])[0]
command_parser = subparsers.add_parser(cmd_name, help=invocation.__doc__)
invocation_parsers[cmd_name] = command_parser
# Add linking capability
command_parser.add_argument('--link', '-l', action='append', nargs=3,
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)")
command_parser.add_argument('--link_node', '-ln', action='append',
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)")
# Convert all fields to arguments
fields = invocation.__fields__
for name, field in fields.items():
if name in ['id', 'type']:
continue
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default,
choices = allowed_values,
help=field.field_info.description
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
help=field.field_info.description
)
return parser
def get_invocation_command(invocation) -> str:
fields = invocation.__fields__.items()
type_hints = get_type_hints(type(invocation))
command = [invocation.type]
for name,field in fields:
if name in ['id', 'type']:
continue
# TODO: add links
# Skip image fields when serializing command
type_hint = type_hints.get(name) or None
if type_hint is ImageField or ImageField in get_args(type_hint):
continue
field_value = getattr(invocation, name)
field_default = field.default
if field_value != field_default:
if type_hint is str or str in get_args(type_hint):
command.append(f'--{name} "{field_value}"')
else:
command.append(f'--{name} {field_value}')
return ' '.join(command)
def get_graph_execution_history(graph_execution_state: GraphExecutionState) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation) -> list[tuple[EdgeConnection, EdgeConnection]]:
"""Generates all possible edges between two invocations"""
atype = type(a)
btype = type(b)
aoutputtype = atype.get_output_type()
afields = get_type_hints(aoutputtype)
bfields = get_type_hints(btype)
matching_fields = set(afields.keys()).intersection(bfields.keys())
# Remove invalid fields
invalid_fields = set(['type', 'id'])
matching_fields = matching_fields.difference(invalid_fields)
edges = [(EdgeConnection(node_id = a.id, field = field), EdgeConnection(node_id = b.id, field = field)) for field in matching_fields]
return edges
def invoke_cli():
args = Args()
config = args.parse_args()
generate = get_generate(args, config)
# NOTE: load model on first use, uncomment to load at startup
# TODO: Make this a config option?
#generate.load_model()
events = EventServiceBase()
output_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../outputs'))
services = InvocationServices(
generate = generate,
events = events,
images = DiskImageStorage(output_folder)
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, 'invokeai.db')
invoker_services = InvokerServices(
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = db_location, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
invoker = Invoker(services, invoker_services)
session = invoker.create_execution_state()
parser = get_invocation_parser()
# Uncomment to print out previous sessions at startup
# print(invoker_services.session_manager.list())
# Defaults storage
defaults: Dict[str, Any] = dict()
while True:
try:
cmd_input = input("> ")
except KeyboardInterrupt:
# Ctrl-c exits
break
if cmd_input in ['exit','q']:
break;
if cmd_input in ['--help','help','h','?']:
parser.print_help()
continue
try:
# Refresh the state of the session
session = invoker.invoker_services.graph_execution_manager.get(session.id)
history = list(get_graph_execution_history(session))
# Split the command for piping
cmds = cmd_input.split('|')
start_id = len(history)
current_id = start_id
new_invocations = list()
for cmd in cmds:
# Parse args to create invocation
args = vars(parser.parse_args(shlex.split(cmd.strip())))
# Check for special commands
# TODO: These might be better as Pydantic models, similar to the invocations
if args['type'] == 'history':
history_count = args['count'] or 5
for i in range(min(history_count, len(history))):
entry_id = history[-1 - i]
entry = session.graph.get_node(entry_id)
print(f'{entry_id}: {get_invocation_command(entry.invocation)}')
continue
if args['type'] == 'reset_default':
if args['input'] in defaults:
del defaults[args['input']]
continue
if args['type'] == 'default':
field = args['input']
field_value = args['value']
defaults[field] = field_value
continue
# Override defaults
for field_name,field_default in defaults.items():
if field_name in args:
args[field_name] = field_default
# Parse invocation
args['id'] = current_id
command = InvocationCommand(invocation = args)
# Pipe previous command output (if there was a previous command)
edges = []
if len(history) > 0 or current_id != start_id:
from_id = history[0] if current_id == start_id else str(current_id - 1)
from_node = next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else session.graph.get_node(from_id)
matching_edges = generate_matching_edges(from_node, command.invocation)
edges.extend(matching_edges)
# Parse provided links
if 'link_node' in args and args['link_node']:
for link in args['link_node']:
link_node = session.graph.get_node(link)
matching_edges = generate_matching_edges(link_node, command.invocation)
edges.extend(matching_edges)
if 'link' in args and args['link']:
for link in args['link']:
edges.append((EdgeConnection(node_id = link[1], field = link[0]), EdgeConnection(node_id = command.invocation.id, field = link[2])))
new_invocations.append((command.invocation, edges))
current_id = current_id + 1
# Command line was parsed successfully
# Add the invocations to the session
for invocation in new_invocations:
session.add_node(invocation[0])
for edge in invocation[1]:
session.add_edge(edge)
# Execute all available invocations
invoker.invoke(session, invoke_all = True)
while not session.is_complete():
# Wait some time
session = invoker.invoker_services.graph_execution_manager.get(session.id)
time.sleep(0.1)
except InvalidArgs:
print('Invalid command, use "help" to list commands')
continue
except SystemExit:
continue
invoker.stop()
if __name__ == "__main__":
invoke_cli()

View File

@ -0,0 +1,8 @@
import os
__all__ = []
dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname):
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
__all__.append(f[:-3])

View File

@ -0,0 +1,74 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints
from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices
class InvocationContext:
services: InvocationServices
graph_execution_state_id: str
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
class BaseInvocationOutput(BaseModel):
"""Base class for all invocation outputs"""
# All outputs must include a type name like this:
# type: Literal['your_output_name']
@classmethod
def get_all_subclasses_tuple(cls):
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
next = toprocess.pop(0)
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return tuple(subclasses)
class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
"""
# All invocations must include a type name like this:
# type: Literal['your_output_name']
@classmethod
def get_all_subclasses(cls):
subclasses = []
toprocess = [cls]
while len(toprocess) > 0:
next = toprocess.pop(0)
next_subclasses = next.__subclasses__()
subclasses.extend(next_subclasses)
toprocess.extend(next_subclasses)
return subclasses
@classmethod
def get_invocations(cls):
return tuple(BaseInvocation.get_all_subclasses())
@classmethod
def get_invocations_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses()))
@classmethod
def get_output_type(cls):
return signature(cls.invoke).return_annotation
@abstractmethod
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
"""Invoke with provided context and return outputs."""
pass
id: str = Field(description="The id of this node. Must be unique among all nodes.")

View File

@ -0,0 +1,42 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import numpy
from pydantic import Field
from PIL import Image, ImageOps
import cv2 as cv
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""
type: Literal['cv_inpaint'] = 'cv_inpaint'
# Inputs
image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert('RGB')), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask))
# Inpaint
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
# Convert back to Pillow
# TODO: consider making a utility function
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

View File

@ -0,0 +1,160 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from pydantic import Field
from PIL import Image
from skimage.exposure.histogram_matching import match_histograms
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
SAMPLER_NAME_VALUES = Literal["ddim","plms","k_lms","k_dpm_2","k_dpm_2_a","k_euler","k_euler_a","k_heun"]
# Text to image
class TextToImageInvocation(BaseInvocation):
"""Generates an image using text2img."""
type: Literal['txt2img'] = 'txt2img'
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image")
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt")
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams")
model: str = Field(default='', description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation")
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(self, context: InvocationContext, sample: Any = None, step: int = 0) -> None:
context.services.events.emit_generator_progress(
context.graph_execution_state_id, self.id, step, float(step) / float(self.steps)
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step = 0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == '':
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt = self.prompt,
step_callback = step_callback,
**self.dict(exclude = {'prompt'}) # Shorthand for passing all of the parameters above manually
)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class ImageToImageInvocation(TextToImageInvocation):
"""Generates an image using img2img."""
type: Literal['img2img'] = 'img2img'
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
fit: bool = Field(default=True, description="Whether or not the result should be fit to the aspect ratio of the input image")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name)
mask = None
def step_callback(sample, step = 0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == '':
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt = self.prompt,
init_img = image,
init_mask = mask,
step_callback = step_callback,
**self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually
)
result_image = results[0][0]
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
type: Literal['inpaint'] = 'inpaint'
# Inputs
mask: Union[ImageField,None] = Field(description="The mask")
inpaint_replace: float = Field(default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = None if self.image is None else context.services.images.get(self.image.image_type, self.image.image_name)
mask = None if self.mask is None else context.services.images.get(self.mask.image_type, self.mask.image_name)
def step_callback(sample, step = 0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == '':
self.model = context.services.generate.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt = self.prompt,
init_img = image,
init_mask = mask,
step_callback = step_callback,
**self.dict(exclude = {'prompt','image','mask'}) # Shorthand for passing all of the parameters above manually
)
result_image = results[0][0]
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

View File

@ -0,0 +1,219 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from pydantic import Field, BaseModel
from PIL import Image, ImageOps, ImageFilter
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_type: str = Field(default=ImageType.RESULT, description="The type of the image")
image_name: Optional[str] = Field(default=None, description="The name of the image")
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal['image'] = 'image'
image: ImageField = Field(default=None, description="The output image")
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
type: Literal['mask'] = 'mask'
mask: ImageField = Field(default=None, description="The output mask")
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
type: Literal['load_image'] = 'load_image'
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image = ImageField(image_type = self.image_type, image_name = self.image_name)
)
class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline."""
type: Literal['show_image'] = 'show_image'
# Inputs
image: ImageField = Field(default=None, description="The image to show")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
if image:
image.show()
# TODO: how to handle failure?
return ImageOutput(
image = ImageField(image_type = self.image.image_type, image_name = self.image.image_name)
)
class CropImageInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
type: Literal['crop'] = 'crop'
# Inputs
image: ImageField = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_crop = Image.new(mode = 'RGBA', size = (self.width, self.height), color = (0, 0, 0, 0))
image_crop.paste(image, (-self.x, -self.y))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class PasteImageInvocation(BaseInvocation):
"""Pastes an image into another image."""
type: Literal['paste'] = 'paste'
# Inputs
base_image: ImageField = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(self.base_image.image_type, self.base_image.image_name)
image = context.services.images.get(self.image.image_type, self.image.image_name)
mask = None if self.mask is None else ImageOps.invert(services.images.get(self.mask.image_type, self.mask.image_name))
# TODO: probably shouldn't invert mask here... should user be required to do it?
min_x = min(0, self.x)
min_y = min(0, self.y)
max_x = max(base_image.width, image.width + self.x)
max_y = max(base_image.height, image.height + self.y)
new_image = Image.new(mode = 'RGBA', size = (max_x - min_x, max_y - min_y), color = (0, 0, 0, 0))
new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask = mask)
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
type: Literal['tomask'] = 'tomask'
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_mask = image.split()[-1]
if self.invert:
image_mask = ImageOps.invert(image_mask)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, image_mask)
return MaskOutput(
mask = ImageField(image_type = image_type, image_name = image_name)
)
class BlurInvocation(BaseInvocation):
"""Blurs an image"""
type: Literal['blur'] = 'blur'
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal['gaussian', 'box'] = Field(default='gaussian', description="The type of blur")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
blur = ImageFilter.GaussianBlur(self.radius) if self.blur_type == 'gaussian' else ImageFilter.BoxBlur(self.radius)
blur_image = image.filter(blur)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class LerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
type: Literal['lerp'] = 'lerp'
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)
class InverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
type: Literal['ilerp'] = 'ilerp'
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

View File

@ -0,0 +1,9 @@
from typing import Literal
from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
type: Literal['prompt'] = 'prompt'
prompt: str = Field(default=None, description="The output prompt")

View File

@ -0,0 +1,36 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
type: Literal['restore_face'] = 'restore_face'
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = None,
strength = self.strength, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

View File

@ -0,0 +1,38 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from .image import ImageField, ImageOutput
from .baseinvocation import BaseInvocation, InvocationContext
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
# Inputs
image: Union[ImageField,None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description = "The upscale level")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
results = context.services.generate.upscale_and_reconstruct(
image_list = [[image, 0]],
upscale = (self.level, self.strength),
strength = 0.0, # GFPGAN strength
save_original = False,
image_callback = None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image = ImageField(image_type = image_type, image_name = image_name)
)

View File

View File

@ -0,0 +1,78 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict
class EventServiceBase:
session_event: str = 'session_event'
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self,
event_name: str,
payload: Dict) -> None:
self.dispatch(
event_name = EventServiceBase.session_event,
payload = dict(
event = event_name,
data = payload
)
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(self,
graph_execution_state_id: str,
invocation_id: str,
step: int,
percent: float
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
event_name = 'generator_progress',
payload = dict(
graph_execution_state_id = graph_execution_state_id,
invocation_id = invocation_id,
step = step,
percent = percent
)
)
def emit_invocation_complete(self,
graph_execution_state_id: str,
invocation_id: str,
result: Dict
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name = 'invocation_complete',
payload = dict(
graph_execution_state_id = graph_execution_state_id,
invocation_id = invocation_id,
result = result
)
)
def emit_invocation_started(self,
graph_execution_state_id: str,
invocation_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name = 'invocation_started',
payload = dict(
graph_execution_state_id = graph_execution_state_id,
invocation_id = invocation_id
)
)
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
event_name = 'graph_execution_state_complete',
payload = dict(
graph_execution_state_id = graph_execution_state_id
)
)

View File

@ -0,0 +1,233 @@
from argparse import Namespace
import os
import sys
import traceback
from ...model_manager import ModelManager
from ...globals import Globals
from ....generate import Generate
import ldm.invoke
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_generate(args, config) -> Generate:
if not args.conf:
config_file = os.path.join(Globals.root,'configs','models.yaml')
if not os.path.exists(config_file):
report_model_error(args, FileNotFoundError(f"The file {config_file} could not be found."))
print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}')
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
gfpgan,codeformer,esrgan = load_face_restoration(args)
# normalize the config directory relative to root
if not os.path.isabs(args.conf):
args.conf = os.path.normpath(os.path.join(Globals.root,args.conf))
if args.embeddings:
if not os.path.isabs(args.embedding_path):
embedding_path = os.path.normpath(os.path.join(Globals.root,args.embedding_path))
else:
embedding_path = args.embedding_path
else:
embedding_path = None
# migrate legacy models
ModelManager.migrate_models()
# load the infile as a list of lines
if args.infile:
try:
if os.path.isfile(args.infile):
infile = open(args.infile, 'r', encoding='utf-8')
elif args.infile == '-': # stdin
infile = sys.stdin
else:
raise FileNotFoundError(f'{args.infile} not found.')
except (FileNotFoundError, IOError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
# creating a Generate object:
try:
gen = Generate(
conf = args.conf,
model = args.model,
sampler_name = args.sampler_name,
embedding_path = embedding_path,
full_precision = args.full_precision,
precision = args.precision,
gfpgan = gfpgan,
codeformer = codeformer,
esrgan = esrgan,
free_gpu_mem = args.free_gpu_mem,
safety_checker = args.safety_checker,
max_loaded_models = args.max_loaded_models,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt,e)
except (IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
if args.seamless:
print(">> changed to seamless tiling mode")
# preload the model
try:
gen.load_model()
except KeyError:
pass
except Exception as e:
report_model_error(args, e)
# try to autoconvert new models
# autoimport new .ckpt files
if path := args.autoconvert:
gen.model_manager.autoconvert_weights(
conf_path=args.conf,
weights_directory=path,
)
return gen
def load_face_restoration(opt):
try:
gfpgan, codeformer, esrgan = None, None, None
if opt.restore or opt.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return gfpgan,codeformer,esrgan
def report_model_error(opt:Namespace, e:Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
if yes_to_all:
print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
else:
response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
if response.startswith(('n', 'N')):
return
print('invokeai-configure is launching....\n')
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_args = sys.argv
sys.argv = [ 'invokeai-configure' ]
sys.argv.extend(root_dir)
sys.argv.extend(config)
if yes_to_all is not None:
for arg in yes_to_all.split():
sys.argv.append(arg)
from ldm.invoke.config import invokeai_configure
invokeai_configure.main()
# TODO: Figure out how to restart
# print('** InvokeAI will now restart')
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)
# Temporary initializer for Generate until we migrate off of it
def old_get_generate(args, config) -> Generate:
# TODO: Remove the need for globals
from ldm.invoke.globals import Globals
# alert - setting globals here
Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
Globals.try_patchmatch = args.patchmatch
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
gfpgan, codeformer, esrgan = None, None, None
try:
if config.restore or config.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if config.restore:
gfpgan, codeformer = restoration.load_face_restore_models(config.gfpgan_model_path)
else:
print('>> Face restoration disabled')
if config.esrgan:
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:
print('>> Face restoration and upscaling disabled')
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root,config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(os.path.join(Globals.root,config.embedding_path))
else:
embedding_path = None
# TODO: lazy-initialize this by wrapping it
try:
generate = Generate(
conf = config.conf,
model = config.model,
sampler_name = config.sampler_name,
embedding_path = embedding_path,
full_precision = config.full_precision,
precision = config.precision,
gfpgan = gfpgan,
codeformer = codeformer,
esrgan = esrgan,
free_gpu_mem = config.free_gpu_mem,
safety_checker = config.safety_checker,
max_loaded_models = config.max_loaded_models,
)
except (FileNotFoundError, TypeError, AssertionError):
#emergency_model_reconfigure() # TODO?
sys.exit(-1)
except (IOError, KeyError) as e:
print(f'{e}. Aborting.')
sys.exit(-1)
generate.free_gpu_mem = config.free_gpu_mem
return generate

View File

@ -0,0 +1,797 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import copy
import itertools
from types import NoneType
import uuid
import networkx as nx
from pydantic import BaseModel, validator
from pydantic.fields import Field
from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints, Annotated
from .invocation_services import InvocationServices
from ..invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ..invocations import *
class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection")
field: str = Field(description="The field for this connection")
def __eq__(self, other):
return (isinstance(other, self.__class__) and
getattr(other, 'node_id', None) == self.node_id and
getattr(other, 'field', None) == self.field)
def __hash__(self):
return hash(f'{self.node_id}.{self.field}')
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type())
node_output_field = node_outputs.get(field) or None
return node_output_field
def get_input_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
return False
if not to_type:
return False
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
if from_type and to_type:
# Ports are compatible
if (from_type == to_type or
from_type == Any or
to_type == Any or
Any in get_args(from_type) or
Any in get_args(to_type)):
return True
if from_type in get_args(to_type):
return True
if to_type in get_args(from_type):
return True
if not issubclass(from_type, to_type):
return False
else:
return False
return True
def are_connections_compatible(
from_node: BaseInvocation,
from_field: str,
to_node: BaseInvocation,
to_field: str) -> bool:
"""Determines if a connection between fields of two nodes is compatible."""
# TODO: handle iterators and collectors
from_node_field = get_output_field(from_node, from_field)
to_node_field = get_input_field(to_node, to_field)
return are_connection_types_compatible(from_node_field, to_node_field)
class NodeAlreadyInGraphError(Exception):
pass
class InvalidEdgeError(Exception):
pass
class NodeNotFoundError(Exception):
pass
class NodeAlreadyExecutedError(Exception):
pass
# TODO: Create and use an Empty output?
class GraphInvocationOutput(BaseInvocationOutput):
type: Literal['graph_output'] = 'graph_output'
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
type: Literal['graph'] = 'graph'
# TODO: figure out how to create a default here
graph: 'Graph' = Field(description="The graph to run", default=None)
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
type: Literal['iterate_output'] = 'iterate_output'
item: Any = Field(description="The item being iterated over")
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
type: Literal['iterate'] = 'iterate'
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
index: int = Field(description="The index, will be provided on executed iterators", default=0)
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
return IterateInvocationOutput(item = self.collection[self.index])
class CollectInvocationOutput(BaseInvocationOutput):
type: Literal['collect_output'] = 'collect_output'
collection: list[Any] = Field(description="The collection of input items")
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""
type: Literal['collect'] = 'collect'
item: Any = Field(description="The item to collect (all inputs must be of the same type)", default=None)
collection: list[Any] = Field(description="The collection, will be provided on execution", default_factory=list)
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
"""Invoke with provided services and return outputs."""
return CollectInvocationOutput(collection = copy.copy(self.collection))
InvocationsUnion = Union[BaseInvocation.get_invocations()]
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[tuple[EdgeConnection,EdgeConnection]] = Field(description="The connections between nodes and their fields in this graph", default_factory=list)
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
:raises NodeAlreadyInGraphError: the node is already present in the graph.
"""
if node.id in self.nodes:
raise NodeAlreadyInGraphError()
self.nodes[node.id] = node
def _get_graph_and_node(self, node_path: str) -> tuple['Graph', str]:
"""Returns the graph and node id for a node path."""
# Materialized graphs may have nodes at the top level
if node_path in self.nodes:
return (self, node_path)
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
if node_id not in self.nodes:
raise NodeNotFoundError(f'Node {node_path} not found in graph')
node = self.nodes[node_id]
if not isinstance(node, GraphInvocation):
# There's more node path left but this isn't a graph - failure
raise NodeNotFoundError('Node path terminated early at a non-graph node')
return node.graph._get_graph_and_node(node_path[node_path.index('.')+1:])
def delete_node(self, node_path: str) -> None:
"""Deletes a node from a graph"""
try:
graph, node_id = self._get_graph_and_node(node_path)
# Delete edges for this node
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
for edge_graph,_,edge in input_edges:
edge_graph.delete_edge(edge)
for edge_graph,_,edge in output_edges:
edge_graph.delete_edge(edge)
del graph.nodes[node_id]
except NodeNotFoundError:
pass # Ignore, not doesn't exist (should this throw?)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
"""Adds an edge to a graph
:raises InvalidEdgeError: the provided edge is invalid.
"""
if self._is_edge_valid(edge) and edge not in self.edges:
self.edges.append(edge)
else:
raise InvalidEdgeError()
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
"""Deletes an edge from a graph"""
try:
self.edges.remove(edge)
except KeyError:
pass
def is_valid(self) -> bool:
"""Validates the graph."""
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
if not gn.graph.is_valid():
return False
# Validate all edges reference nodes in the graph
node_ids = set([e[0].node_id for e in self.edges]+[e[1].node_id for e in self.edges])
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
# Validate there are no cycles
g = self.nx_graph_flat()
if not nx.is_directed_acyclic_graph(g):
return False
# Validate all edge connections are valid
if not all((are_connections_compatible(
self.get_node(e[0].node_id), e[0].field,
self.get_node(e[1].node_id), e[1].field
) for e in self.edges)):
return False
# Validate all iterators
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
if not all((self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))):
return False
# Validate all collectors
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
if not all((self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))):
return False
return True
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
try:
from_node = self.get_node(edge[0].node_id)
to_node = self.get_node(edge[1].node_id)
except NodeNotFoundError:
return False
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge[0].node_id, edge[1].node_id)
if not nx.is_directed_acyclic_graph(g):
return False
# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge[0].field, to_node, edge[1].field):
return False
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge[1].field == 'collection':
if not self._is_iterator_connection_valid(edge[1].node_id, new_input = edge[0]):
return False
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge[0].field == 'item':
if not self._is_iterator_connection_valid(edge[0].node_id, new_output = edge[1]):
return False
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge[1].field == 'item':
if not self._is_collector_connection_valid(edge[1].node_id, new_input = edge[0]):
return False
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge[0].field == 'collection':
if not self._is_collector_connection_valid(edge[0].node_id, new_output = edge[1]):
return False
return True
def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph."""
try:
n = self.get_node(node_path)
if n is not None:
return True
else:
return False
except NodeNotFoundError:
return False
def get_node(self, node_path: str) -> InvocationsUnion:
"""Gets a node from the graph using a node path."""
# Materialized graphs may have nodes at the top level
graph, node_id = self._get_graph_and_node(node_path)
return graph.nodes[node_id]
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
return node_id if prefix is None or prefix == '' else f'{prefix}.{node_id}'
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
"""Updates a node in the graph."""
graph, node_id = self._get_graph_and_node(node_path)
node = graph.nodes[node_id]
# Ensure the node type matches the new node
if type(node) != type(new_node):
raise TypeError(f'Node {node_path} is type {type(node)} but new node is type {type(new_node)}')
# Ensure the new id is either the same or is not in the graph
prefix = None if '.' not in node_path else node_path[:node_path.rindex('.')]
new_path = self._get_node_path(new_node.id, prefix = prefix)
if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError('Node with id {new_node.id} already exists in graph')
# Set the new node in the graph
graph.nodes[new_node.id] = new_node
if new_node.id != node.id:
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
# Delete node and all edges
graph.delete_node(node_path)
# Create new edges for each input and output
for graph,_,edge in input_edges:
# Remove the graph prefix from the node path
new_graph_node_path = new_node.id if '.' not in edge[1].node_id else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
graph.add_edge((edge[0], EdgeConnection(node_id = new_graph_node_path, field = edge[1].field)))
for graph,_,edge in output_edges:
# Remove the graph prefix from the node path
new_graph_node_path = new_node.id if '.' not in edge[0].node_id else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
graph.add_edge((EdgeConnection(node_id = new_graph_node_path, field = edge[0].field), edge[1]))
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[tuple[EdgeConnection,EdgeConnection]]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
# Create full node paths for each edge
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
def _get_input_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e[1].node_id == node_path])
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix)
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path)
edges.extend(graph_edges)
return edges
def _get_output_edges(self, node_path: str, field: str) -> list[tuple[EdgeConnection,EdgeConnection]]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2][0].field == field)
# Create full node paths for each edge
return [(EdgeConnection(node_id = self._get_node_path(e[0].node_id, prefix = prefix), field=e[0].field), EdgeConnection(node_id = self._get_node_path(e[1].node_id, prefix = prefix), field=e[1].field)) for _,prefix,e in filtered_edges]
def _get_output_edges_and_graphs(self, node_path: str, prefix: Optional[str] = None) -> list[tuple['Graph', str, tuple[EdgeConnection,EdgeConnection]]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e[0].node_id == node_path])
node_id = node_path if '.' not in node_path else node_path[:node_path.index('.')]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == '' else self._get_node_path(node.id, prefix = prefix)
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id)+1):], prefix=graph_path)
edges.extend(graph_edges)
return edges
def _is_iterator_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, 'collection')])
outputs = list([e[1] for e in self._get_output_edges(node_path, 'item')])
if new_input is not None:
inputs.append(new_input)
if new_output is not None:
outputs.append(new_output)
# Only one input is allowed for iterators
if len(inputs) > 1:
return False
# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
# Input type must be a list
if get_origin(input_field) != list:
return False
# Validate that all outputs match the input type
input_field_item_type = get_args(input_field)[0]
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
return False
return True
def _is_collector_connection_valid(self, node_path: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, 'item')])
outputs = list([e[1] for e in self._get_output_edges(node_path, 'collection')])
if new_input is not None:
inputs.append(new_input)
if new_output is not None:
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
# Validate that all inputs are derived from or match a single type
input_field_types = set([t for input_field in input_fields for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) if t != NoneType]) # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1:
return False # There is more than one root type
# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0)
# Verify that all outputs are lists
if not all((get_origin(f) == list for f in output_fields)):
return False
# Verify that all outputs match the input type (are a base class or the same class)
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
return False
return True
def nx_graph(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the layout of this graph"""
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph()
# Add all nodes from this graph except graph/iteration nodes
g.add_nodes_from([self._get_node_path(n.id, prefix) for n in self.nodes.values() if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)])
# Expand graph nodes
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
return g
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state", default_factory=uuid.uuid4)
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")
# The graph of materialized nodes
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes", default_factory=Graph)
# Nodes that have been executed
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
executed_history: list[str] = Field(description="The list of node ids that have been executed, in order of execution", default_factory=list)
# The results of executed nodes
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(description="The results of node executions", default_factory=dict)
# Map of prepared/executed nodes to their original nodes
prepared_source_mapping: dict[str, str] = Field(description="The map of prepared nodes to original graph nodes", default_factory=dict)
# Map of original nodes to prepared nodes
source_prepared_mapping: dict[str, set[str]] = Field(description="The map of original graph nodes to prepared nodes", default_factory=dict)
def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
# possibly with a timeout?
# If there are no prepared nodes, prepare some nodes
next_node = self._get_next_node()
if next_node is None:
prepared_id = self._prepare()
# TODO: prepare multiple nodes at once?
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
# prepared_id = self._prepare()
if prepared_id is not None:
next_node = self._get_next_node()
# Get values from edges
if next_node is not None:
self._prepare_inputs(next_node)
# If next is still none, there's no next node, return None
return next_node
def complete(self, node_id: str, output: InvocationOutputsUnion):
"""Marks a node as complete"""
if node_id not in self.execution_graph.nodes:
return # TODO: log error?
# Mark node as executed
self.executed.add(node_id)
self.results[node_id] = output
# Check if source node is complete (all prepared nodes are complete)
source_node = self.prepared_source_mapping[node_id]
prepared_nodes = self.source_prepared_mapping[source_node]
if all([n in self.executed for n in prepared_nodes]):
self.executed.add(source_node)
self.executed_history.append(source_node)
def is_complete(self) -> bool:
"""Returns true if the graph is complete"""
return all((k in self.executed for k in self.graph.nodes))
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_path)
self_iteration_count = -1
# If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, 'collection')))
input_collection_prepared_node_id = next(n[1] for n in iteration_node_map if n[0] == input_collection_edge[0].node_id)
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge[0].field)
self_iteration_count = len(input_collection)
new_nodes = list()
if self_iteration_count == 0:
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
return new_nodes
# Get all input edges
input_edges = self.graph._get_input_edges(node_path)
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
new_edges = list()
for edge in input_edges:
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge[0].node_id):
new_edge = (EdgeConnection(node_id = input_node_id, field = edge[0].field), EdgeConnection(node_id = '', field = edge[1].field))
new_edges.append(new_edge)
# Create a new node (or one for each iteration of this iterator)
for i in (range(self_iteration_count) if self_iteration_count > 0 else [-1]):
# Create a new node
new_node = copy.deepcopy(node)
# Create the node id (use a random uuid)
new_node.id = str(uuid.uuid4())
# Set the iteration index for iteration invocations
if isinstance(new_node, IterateInvocation):
new_node.index = i
# Add to execution graph
self.execution_graph.add_node(new_node)
self.prepared_source_mapping[new_node.id] = node_path
if node_path not in self.source_prepared_mapping:
self.source_prepared_mapping[node_path] = set()
self.source_prepared_mapping[node_path].add(new_node.id)
# Add new edges to execution graph
for edge in new_edges:
new_edge = (edge[0], EdgeConnection(node_id = new_node.id, field = edge[1].field))
self.execution_graph.add_edge(new_edge)
new_nodes.append(new_node.id)
return new_nodes
def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph()
collectors = (n for n in self.graph.nodes if isinstance(self.graph.nodes[n], CollectInvocation))
for c in collectors:
g.remove_edges_from(list(g.in_edges(c)))
return g
def _get_node_iterators(self, node_id: str) -> list[str]:
"""Gets iterators for a node"""
g = self._iterator_graph()
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.nodes[n], IterateInvocation)]
return iterators
def _prepare(self) -> Optional[str]:
# Get flattened source graph
g = self.graph.nx_graph_flat()
# Find next unprepared node where all source nodes are executed
sorted_nodes = nx.topological_sort(g)
next_node_id = next((n for n in sorted_nodes if n not in self.source_prepared_mapping and all((e[0] in self.executed for e in g.in_edges(n)))), None)
if next_node_id == None:
return None
# Get all parents of the next node
next_node_parents = [e[0] for e in g.in_edges(next_node_id)]
# Create execution nodes
next_node = self.graph.get_node(next_node_id)
new_node_ids = list()
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(itertools.chain(*(((s,p) for p in self.source_prepared_mapping[s]) for s in next_node_parents)))
#all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)
else: # Iterators or normal nodes
# Get all iterator combinations for this node
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
iterator_nodes = self._get_node_iterators(next_node_id)
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
# Select the correct prepared parents for each iteration
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [[(n,self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]
# Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings:
create_results = self._create_execution_node(next_node_id, iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)
return next(iter(new_node_ids), None)
def _get_iteration_node(self, source_node_path: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str]) -> Optional[str]:
"""Gets the prepared version of the specified source node that matches every iteration specified"""
prepared_nodes = self.source_prepared_mapping[source_node_path]
if len(prepared_nodes) == 1:
return next(iter(prepared_nodes))
# Check if the requested node is an iterator
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
if prepared_iterator is not None:
return prepared_iterator
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
return next((n for n in prepared_nodes if all(pit for pit in parent_iterators if nx.has_path(execution_graph, pit[0], n))), None)
def _get_next_node(self) -> Optional[BaseInvocation]:
g = self.execution_graph.nx_graph()
sorted_nodes = nx.topological_sort(g)
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
if next_node is None:
return None
return self.execution_graph.nodes[next_node]
def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
if isinstance(node, CollectInvocation):
output_collection = [getattr(self.results[edge[0].node_id], edge[0].field) for edge in input_edges if edge[1].field == 'item']
setattr(node, 'collection', output_collection)
else:
for edge in input_edges:
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
setattr(node, edge[1].field, output_value)
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
if not self._is_edge_valid(edge):
return False
# Invalid if destination has already been prepared or executed
if edge[1].node_id in self.source_prepared_mapping:
return False
# Otherwise, the edge is valid
return True
def _is_node_updatable(self, node_id: str) -> bool:
# The node is updatable as long as it hasn't been prepared or executed
return node_id not in self.source_prepared_mapping
def add_node(self, node: BaseInvocation) -> None:
self.graph.add_node(node)
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_path):
raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be updated')
self.graph.update_node(node_path, new_node)
def delete_node(self, node_path: str) -> None:
if not self._is_node_updatable(node_path):
raise NodeAlreadyExecutedError(f'Node {node_path} has already been prepared or executed and cannot be deleted')
self.graph.delete_node(node_path)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to')
self.graph.add_edge(edge)
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
raise NodeAlreadyExecutedError(f'Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted')
self.graph.delete_edge(edge)
GraphInvocation.update_forward_refs()

View File

@ -0,0 +1,104 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from enum import Enum
import datetime
import os
from pathlib import Path
from queue import Queue
from typing import Dict
from PIL.Image import Image
from ...pngwriter import PngWriter
class ImageType(str, Enum):
RESULT = 'results'
INTERMEDIATE = 'intermediates'
UPLOAD = 'uploads'
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
pass
def create_name(self, context_id: str, node_id: str) -> str:
return f'{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png'
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__pngWriter: PngWriter
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
self.__pngWriter = PngWriter(output_folder)
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(parents=True, exist_ok=True)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_type: ImageType, image_name: str) -> str:
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(image, "", image_subpath, None) # TODO: just pass full path to png writer
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)
if os.path.exists(image_path):
os.remove(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
def __get_cache(self, image_name: str) -> Image:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
del self.__cache[cache_id]

View File

@ -0,0 +1,46 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from queue import Queue
# TODO: make this serializable
class InvocationQueueItem:
#session_id: str
graph_execution_state_id: str
invocation_id: str
invoke_all: bool
def __init__(self,
#session_id: str,
graph_execution_state_id: str,
invocation_id: str,
invoke_all: bool = False):
#self.session_id = session_id
self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id
self.invoke_all = invoke_all
class InvocationQueueABC(ABC):
"""Abstract base class for all invocation queues"""
@abstractmethod
def get(self) -> InvocationQueueItem:
pass
@abstractmethod
def put(self, item: InvocationQueueItem|None) -> None:
pass
class MemoryInvocationQueue(InvocationQueueABC):
__queue: Queue
def __init__(self):
self.__queue = Queue()
def get(self) -> InvocationQueueItem:
return self.__queue.get()
def put(self, item: InvocationQueueItem|None) -> None:
self.__queue.put(item)

View File

@ -0,0 +1,20 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from .image_storage import ImageStorageBase
from .events import EventServiceBase
from ....generate import Generate
class InvocationServices():
"""Services that can be used by invocations"""
generate: Generate # TODO: wrap Generate, or split it up from model?
events: EventServiceBase
images: ImageStorageBase
def __init__(self,
generate: Generate,
events: EventServiceBase,
images: ImageStorageBase
):
self.generate = generate
self.events = events
self.images = images

View File

@ -0,0 +1,109 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC
from threading import Event, Thread
from .graph import Graph, GraphExecutionState
from .item_storage import ItemStorageABC
from ..invocations.baseinvocation import InvocationContext
from .invocation_services import InvocationServices
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
class InvokerServices:
"""Services used by the Invoker for execution"""
queue: InvocationQueueABC
graph_execution_manager: ItemStorageABC[GraphExecutionState]
processor: 'InvocationProcessorABC'
def __init__(self,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC[GraphExecutionState],
processor: 'InvocationProcessorABC'):
self.queue = queue
self.graph_execution_manager = graph_execution_manager
self.processor = processor
class Invoker:
"""The invoker, used to execute invocations"""
services: InvocationServices
invoker_services: InvokerServices
def __init__(self,
services: InvocationServices, # Services used by nodes to perform invocations
invoker_services: InvokerServices # Services used by the invoker for orchestration
):
self.services = services
self.invoker_services = invoker_services
self._start()
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
# Get the next invocation
invocation = graph_execution_state.next()
if not invocation:
return None
# Save the execution state
self.invoker_services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
print(f'queueing item {invocation.id}')
self.invoker_services.queue.put(InvocationQueueItem(
#session_id = session.id,
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id,
invoke_all = invoke_all
))
return invocation.id
def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
self.invoker_services.graph_execution_manager.set(new_state)
return new_state
def __start_service(self, service) -> None:
# Call start() method on any services that have it
start_op = getattr(service, 'start', None)
if callable(start_op):
start_op(self)
def __stop_service(self, service) -> None:
# Call stop() method on any services that have it
stop_op = getattr(service, 'stop', None)
if callable(stop_op):
stop_op(self)
def _start(self) -> None:
"""Starts the invoker. This is called automatically when the invoker is created."""
for service in vars(self.invoker_services):
self.__start_service(getattr(self.invoker_services, service))
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
for service in vars(self.invoker_services):
self.__stop_service(getattr(self.invoker_services, service))
self.invoker_services.queue.put(None)
class InvocationProcessorABC(ABC):
pass

View File

@ -0,0 +1,57 @@
from typing import Callable, TypeVar, Generic
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from abc import ABC, abstractmethod
T = TypeVar('T', bound=BaseModel)
class PaginatedResults(GenericModel, Generic[T]):
"""Paginated results"""
items: list[T] = Field(description = "Items")
page: int = Field(description = "Current Page")
pages: int = Field(description = "Total number of pages")
per_page: int = Field(description = "Number of items per page")
total: int = Field(description = "Total number of items in result")
class ItemStorageABC(ABC, Generic[T]):
_on_changed_callbacks: list[Callable[[T], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
"""Base item storage class"""
@abstractmethod
def get(self, item_id: str) -> T:
pass
@abstractmethod
def set(self, item: T) -> None:
pass
@abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
pass
@abstractmethod
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
pass
def on_changed(self, on_changed: Callable[[T], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, item: T) -> None:
for callback in self._on_changed_callbacks:
callback(item)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)

View File

@ -0,0 +1,78 @@
from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
def start(self, invoker) -> None:
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
name = "invoker_processor",
target = self.__process,
kwargs = dict(stop_event = self.__stop_event)
)
self.__invoker_thread.daemon = True # TODO: probably better to just not use threads?
self.__invoker_thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def __process(self, stop_event: Event):
try:
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.invoker_services.queue.get()
if not queue_item: # Probably stopping
continue
graph_execution_state = self.__invoker.invoker_services.graph_execution_manager.get(queue_item.graph_execution_state_id)
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id
)
# Invoke
try:
outputs = invocation.invoke(InvocationContext(
services = self.__invoker.services,
graph_execution_state_id = graph_execution_state.id
))
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.invoker_services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id = graph_execution_state.id,
invocation_id = invocation.id,
result = outputs.dict()
)
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
self.__invoker.invoke(graph_execution_state, invoke_all = True)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
except KeyboardInterrupt:
pass
except Exception as e:
# TODO: Log the error, mark the invocation as failed, and emit an event
print(f'Error invoking {invocation.id}: {e}')
pass
except KeyboardInterrupt:
... # Log something?

View File

@ -0,0 +1,119 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults
T = TypeVar('T', bound=BaseModel)
sqlite_memory = ':memory:'
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: Lock
def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
super().__init__()
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
try:
self._lock.acquire()
self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return self._parse_item(result[0])
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page))
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''')
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[T](
items = items,
page = page,
pages = pageCount,
per_page = per_page,
total = count
)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page))
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',))
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[T](
items = items,
page = page,
pages = pageCount,
per_page = per_page,
total = count
)

View File

@ -45,6 +45,9 @@ dependencies = [
"einops",
"eventlet",
"facexlib",
"fastapi==0.85.0",
"fastapi-events==0.6.0",
"fastapi-socketio==0.0.9",
"flask==2.1.3",
"flask_cors==3.0.10",
"flask_socketio==5.3.0",
@ -66,6 +69,7 @@ dependencies = [
"prompt-toolkit",
"pypatchmatch",
"pyreadline3",
"python-multipart==0.0.5",
"pytorch-lightning==1.7.7",
"realesrgan",
"requests==2.28.2",
@ -80,6 +84,7 @@ dependencies = [
"torchvision>=0.14.1",
"torchmetrics",
"transformers~=4.25",
"uvicorn[standard]==0.20.0",
"windows-curses; sys_platform=='win32'",
]

20
scripts/invoke-new.py Normal file
View File

@ -0,0 +1,20 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
import sys
def main():
# Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
if '--web' in sys.argv:
from ldm.invoke.app.api_app import invoke_api
invoke_api()
else:
# TODO: Parse some top-level args here.
from ldm.invoke.app.cli_app import invoke_cli
invoke_cli()
if __name__ == '__main__':
main()

206
static/dream_web/test.html Normal file
View File

@ -0,0 +1,206 @@
<html lang="en">
<head>
<title>InvokeAI Test</title>
<meta charset="utf-8">
<link rel="icon" type="image/x-icon" href="static/dream_web/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<!--<script src="config.js"></script>-->
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"
integrity="sha512-q/dWJ3kcmjBLU4Qc47E4A9kTB4m3wuTY7vkFJDTZKjTs8jhyGQnaUrxa0Ytd0ssMZhbNua9hE+E7Qv1j+DyZwA=="
crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/fslightbox/3.0.9/index.js"></script>
<style>
body {
background:#334;
}
#gallery > a.image {
display:inline-block;
width:128px;
height:128px;
margin:8px;
background-size:contain;
background-repeat:no-repeat;
background-position:center;
}
.results {
border-color:green;
border-width:1px;
border-style:solid;
}
.intermediates {
border-color:yellow;
border-width:1px;
border-style:solid;
}
</style>
</head>
<body>
<button onclick="dostuff();">Test Invoke</button>
<div id="gallery"></div>
<div id="textlog">
<textarea id='log' rows=10 cols=60 autofocus>
</textarea>
</div>
<script type="text/javascript">
const socket_url = `ws://${window.location.host}`;
const socket = io(socket_url, {
path: "/ws/socket.io"
});
socket.on('connect', (data) => {
//socket.emit('subscribe', { 'session': 'WcBtYATwT92Mrb9zLgeyNw==' });
});
function addGalleryImage(src, className) {
let gallery = document.getElementById("gallery");
let div = document.createElement("a");
div.href = src;
div.setAttribute('data-fslightbox', '');
div.classList.add("image");
div.classList.add(className);
div.style.backgroundImage = `url('${src}')`;
gallery.appendChild(div);
refreshFsLightbox();
}
function log(event, data) {
document.getElementById("log").value += `${event} => ${JSON.stringify(data)}\n`;
if (data.result?.image?.image_type) {
addGalleryImage(`/api/v1/images/${data.result.image.image_type}/${data.result.image.image_name}`, data.result.image.image_type);
}
if (data.result?.mask?.image_type) {
addGalleryImage(`/api/v1/images/${data.result.mask.image_type}/${data.result.mask.image_name}`, data.result.mask.image_type);
}
console.log(`${event} => ${JSON.stringify(data)}`);
}
socket.on('generator_progress', (data) => log('generator_progress', data));
socket.on('invocation_complete', (data) => log('invocation_complete', data));
socket.on('invocation_started', (data) => log('invocation_started', data));
socket.on('session_complete', (data) => {
log('session_complete', data);
// NOTE: you may not want to unsubscribe if you plan to continue using this session,
// just make sure you unsubscribe eventually
socket.emit('unsubscribe', { 'session': data.session_id });
});
function dostuff() {
let prompt = "hyper detailed 4k cryengine 3D render of a cat in a dune buggy, trending on artstation, soft atmospheric lighting, volumetric lighting, cinematic still, golden hour, crepuscular rays, smooth [grainy]";
let strength = 0.95;
let sampler = 'keuler_a';
let steps = 50;
let seed = -1;
// Start building nodes
var id = 1;
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
id++;
var upscaleNode = {"id": id.toString(), "type": "show_image" };
id++
nodes = {};
nodes[initialNode.id] = initialNode;
nodes[upscaleNode.id] = upscaleNode;
links = [
[{ "node_id": initialNode.id, field: "image" },
{ "node_id": upscaleNode.id, field: "image" }]
];
// expandSize = 128;
// for (var i = 0; i < 6; ++i) {
// var i_seed = (seed == -1) ? -1 : seed + i + 1;
// var startid = id - 1;
// var offset = (i < 3) ? -expandSize : (3 * expandSize) + ((i - 3 + 1) * expandSize);
// nodes.push({"id": id.toString(), "type": "crop", "x": offset, "y": 0, "width": 512, "height": 512});
// let id_gen = id;
// links.push({"from_node": {"id": startid.toString(), "field": "image"},"to_node": {"id": id_gen.toString(),"field": "image"}});
// id++;
// nodes.push({"id": id.toString(), "type": "tomask"});
// let id_mask = id;
// links.push({"from_node": {"id": id_gen.toString(), "field": "image"},"to_node": {"id": id_mask.toString(),"field": "image"}});
// id++;
// nodes.push({"id": id.toString(), "type": "blur", "radius": 32});
// let id_mask_blur = id;
// links.push({"from_node": {"id": id_mask.toString(), "field": "mask"},"to_node": {"id": id_mask_blur.toString(),"field": "image"}});
// id++
// nodes.push({"id": id.toString(), "type": "ilerp", "min": 128, "max": 255});
// let id_ilerp = id;
// links.push({"from_node": {"id": id_mask_blur.toString(), "field": "image"},"to_node": {"id": id_ilerp.toString(),"field": "image"}});
// id++
// nodes.push({"id": id.toString(), "type": "cv_inpaint"});
// let id_cv_inpaint = id;
// links.push({"from_node": {"id": id_gen.toString(), "field": "image"},"to_node": {"id": id_cv_inpaint.toString(),"field": "image"}});
// links.push({"from_node": {"id": id_mask.toString(), "field": "mask"},"to_node": {"id": id_cv_inpaint.toString(),"field": "mask"}});
// id++;
// nodes.push({"id": id.toString(), "type": "img2img", "prompt": prompt, "strength": strength, "sampler": sampler, "steps": steps, "seed": i_seed, "color_match": true, "inpaint_replace": inpaint_replace});
// let id_i2i = id;
// links.push({"from_node": {"id": id_cv_inpaint.toString(), "field": "image"},"to_node": {"id": id_i2i.toString(),"field": "image"}});
// links.push({"from_node": {"id": id_ilerp.toString(), "field": "image"},"to_node": {"id": id_i2i.toString(),"field": "mask"}});
// id++;
// nodes.push({"id": id.toString(), "type": "paste", "x": offset, "y": 0});
// let id_paste = id;
// links.push({"from_node": {"id": startid.toString(), "field": "image"},"to_node": {"id": id_paste.toString(),"field": "base_image"}});
// links.push({"from_node": {"id": id_i2i.toString(), "field": "image"},"to_node": {"id": id_paste.toString(),"field": "image"}});
// links.push({"from_node": {"id": id_ilerp.toString(), "field": "image"},"to_node": {"id": id_paste.toString(),"field": "mask"}});
// id++;
// }
var graph = {
"nodes": nodes,
"edges": links
};
// var defaultGraph = {"nodes": [
// {"id": "1", "type": "txt2img", "prompt": prompt},
// {"id": "2", "type": "crop", "x": -256, "y": 128, "width": 512, "height": 512},
// {"id": "3", "type": "tomask"},
// {"id": "4", "type": "cv_inpaint"},
// {"id": "5", "type": "img2img", "prompt": prompt, "strength": 0.9},
// {"id": "6", "type": "paste", "x": -256, "y": 128},
// ],
// "links": [
// {"from_node": {"id": "1","field": "image"},"to_node": {"id": "2","field": "image"}},
// {"from_node": {"id": "2","field": "image"},"to_node": {"id": "3","field": "image"}},
// {"from_node": {"id": "2","field": "image"},"to_node": {"id": "4","field": "image"}},
// {"from_node": {"id": "3","field": "mask"},"to_node": {"id": "4","field": "mask"}},
// {"from_node": {"id": "4","field": "image"},"to_node": {"id": "5","field": "image"}},
// {"from_node": {"id": "3","field": "mask"},"to_node": {"id": "5","field": "mask"}},
// {"from_node": {"id": "1","field": "image"},"to_node": {"id": "6","field": "base_image"}},
// {"from_node": {"id": "5","field": "image"},"to_node": {"id": "6","field": "image"}}
// ]};
fetch('/api/v1/sessions/', {
method: 'POST',
headers: new Headers({'content-type': 'application/json'}),
body: JSON.stringify(graph)
}).then(response => response.json())
.then(data => {
sessionId = data.id;
socket.emit('subscribe', { 'session': sessionId });
fetch(`/api/v1/sessions/${sessionId}/invoke?all=true`, {
method: 'PUT',
headers: new Headers({'content-type': 'application/json'}),
body: ''
});
});
}
</script>
</body>
</html>

0
tests/__init__.py Normal file
View File

0
tests/nodes/__init__.py Normal file
View File

View File

@ -0,0 +1,114 @@
from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@pytest.fixture
def simple_graph():
g = Graph()
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
g.add_node(ImageTestInvocation(id = "2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = None, images = None)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next()
if n is None:
return (None, None)
print(f'invoking {n.id}: {type(n)}')
o = n.invoke(InvocationContext(services, "1"))
g.complete(n.id, o)
return (n, o)
def test_graph_state_executes_in_order(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = g.next()
assert g.prepared_source_mapping[n1[0].id] == "1"
assert g.prepared_source_mapping[n2[0].id] == "2"
assert n3 is None
assert g.results[n1[0].id].prompt == n1[0].prompt
assert n2[0].prompt == n1[0].prompt
def test_graph_is_complete(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = g.next()
assert g.is_complete()
def test_graph_is_not_complete(simple_graph, mock_services):
g = GraphExecutionState(graph = simple_graph)
n1 = invoke_next(g, mock_services)
n2 = g.next()
assert not g.is_complete()
# TODO: test completion with iterators/subgraphs
def test_graph_state_expands_iterator(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
graph.add_node(IterateInvocation(id = "2"))
graph.add_node(ImageTestInvocation(id = "3"))
graph.add_edge(create_edge("1", "collection", "2", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt"))
g = GraphExecutionState(graph = graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
n5 = invoke_next(g, mock_services)
assert g.prepared_source_mapping[n1[0].id] == "1"
assert g.prepared_source_mapping[n2[0].id] == "2"
assert g.prepared_source_mapping[n3[0].id] == "2"
assert g.prepared_source_mapping[n4[0].id] == "3"
assert g.prepared_source_mapping[n5[0].id] == "3"
assert isinstance(n4[0], ImageTestInvocation)
assert isinstance(n5[0], ImageTestInvocation)
prompts = [n4[0].prompt, n5[0].prompt]
assert sorted(prompts) == sorted(test_prompts)
def test_graph_state_collects(mock_services):
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
graph.add_node(IterateInvocation(id = "2"))
graph.add_node(PromptTestInvocation(id = "3"))
graph.add_node(CollectInvocation(id = "4"))
graph.add_edge(create_edge("1", "collection", "2", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt"))
graph.add_edge(create_edge("3", "prompt", "4", "item"))
g = GraphExecutionState(graph = graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
n5 = invoke_next(g, mock_services)
n6 = invoke_next(g, mock_services)
assert isinstance(n6[0], CollectInvocation)
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)

View File

@ -0,0 +1,85 @@
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
from ldm.invoke.app.services.processor import DefaultInvocationProcessor
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from ldm.invoke.app.services.invocation_queue import MemoryInvocationQueue
from ldm.invoke.app.services.invoker import Invoker, InvokerServices
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.services.invocation_services import InvocationServices
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
@pytest.fixture
def simple_graph():
g = Graph()
g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
g.add_node(ImageTestInvocation(id = "2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(generate = None, events = TestEventService(), images = None)
@pytest.fixture()
def mock_invoker_services() -> InvokerServices:
return InvokerServices(
queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor()
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices, mock_invoker_services: InvokerServices) -> Invoker:
return Invoker(
services = mock_services,
invoker_services = mock_invoker_services
)
def test_can_create_graph_state(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph
def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g)
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g, invoke_all = True)
assert invocation_id is not None
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
mock_invoker.stop()
g = mock_invoker.invoker_services.graph_execution_manager.get(g.id)
assert g.is_complete()

View File

@ -0,0 +1,501 @@
from ldm.invoke.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from ldm.invoke.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from ldm.invoke.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from ldm.invoke.app.invocations.upscale import UpscaleInvocation
import pytest
# Helpers
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
# Tests
def test_connections_are_compatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_field = "image"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == True
def test_connections_are_incompatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_field = "strength"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
def test_connections_incompatible_with_invalid_fields():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
from_field = "invalid_field"
to_node = UpscaleInvocation(id = "2")
to_field = "image"
# From field is invalid
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
# To field is invalid
from_field = "image"
to_field = "invalid_field"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
assert result == False
def test_graph_can_add_node():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
assert n.id in g.nodes
def test_graph_fails_to_add_node_with_duplicate_id():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second")
with pytest.raises(NodeAlreadyInGraphError):
g.add_node(n2)
def test_graph_updates_node():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2)
nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated")
g.update_node("1", nu)
assert g.nodes["1"].prompt == "Banana sushi updated"
def test_graph_fails_to_update_node_if_type_changes():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
g.add_node(n2)
nu = UpscaleInvocation(id = "1")
with pytest.raises(TypeError):
g.update_node("1", nu)
def test_graph_allows_non_conflicting_id_change():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
g.add_node(n2)
e1 = create_edge(n.id,"image",n2.id,"image")
g.add_edge(e1)
nu = TextToImageInvocation(id = "3", prompt = "Banana sushi")
g.update_node("1", nu)
with pytest.raises(NodeNotFoundError):
g.get_node("1")
assert g.get_node("3").prompt == "Banana sushi"
assert len(g.edges) == 1
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
def test_graph_fails_to_update_node_id_if_conflict():
g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2)
nu = TextToImageInvocation(id = "2", prompt = "Banana sushi")
with pytest.raises(NodeAlreadyInGraphError):
g.update_node("1", nu)
def test_graph_adds_edge():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
assert e in g.edges
def test_graph_fails_to_add_edge_with_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
g.add_node(n1)
e = create_edge(n1.id,"image",n1.id,"image")
with pytest.raises(InvalidEdgeError):
g.add_edge(e)
def test_graph_fails_to_add_edge_with_long_cycle():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge(n1.id,"image",n2.id,"image")
e2 = create_edge(n2.id,"image",n3.id,"image")
e3 = create_edge(n3.id,"image",n2.id,"image")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_fails_to_add_edge_with_missing_node_id():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","3","image")
e2 = create_edge("3","image","1","image")
with pytest.raises(InvalidEdgeError):
g.add_edge(e1)
with pytest.raises(InvalidEdgeError):
g.add_edge(e2)
def test_graph_fails_to_add_edge_when_destination_exists():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge(n1.id,"image",n2.id,"image")
e2 = create_edge(n1.id,"image",n3.id,"image")
e3 = create_edge(n2.id,"image",n3.id,"image")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_fails_to_add_edge_with_mismatched_types():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","strength")
with pytest.raises(InvalidEdgeError):
g.add_edge(e1)
def test_graph_connects_collector():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3")
n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","image","3","item")
e2 = create_edge("2","image","3","item")
e3 = create_edge("3","collection","4","collection")
g.add_edge(e1)
g.add_edge(e2)
g.add_edge(e3)
# TODO: test that derived types mixed with base types are compatible
def test_graph_collector_invalid_with_varying_input_types():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
n3 = CollectInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge("1","image","3","item")
e2 = create_edge("2","prompt","3","item")
g.add_edge(e1)
with pytest.raises(InvalidEdgeError):
g.add_edge(e2)
def test_graph_collector_invalid_with_varying_input_output():
g = Graph()
n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3")
n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","prompt","3","item")
e2 = create_edge("2","prompt","3","item")
e3 = create_edge("3","collection","4","collection")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_collector_invalid_with_non_list_output():
g = Graph()
n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3")
n4 = PromptTestInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","prompt","3","item")
e2 = create_edge("2","prompt","3","item")
e3 = create_edge("3","collection","4","prompt")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_connects_iterator():
g = Graph()
n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge("1","collection","2","collection")
e2 = create_edge("2","item","3","image")
g.add_edge(e1)
g.add_edge(e2)
# TODO: TEST INVALID ITERATOR SCENARIOS
def test_graph_iterator_invalid_if_multiple_inputs():
g = Graph()
n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
e1 = create_edge("1","collection","2","collection")
e2 = create_edge("2","item","3","image")
e3 = create_edge("4","collection","2","collection")
g.add_edge(e1)
g.add_edge(e2)
with pytest.raises(InvalidEdgeError):
g.add_edge(e3)
def test_graph_iterator_invalid_if_input_not_list():
g = Graph()
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
n2 = IterateInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","collection","2","collection")
with pytest.raises(InvalidEdgeError):
g.add_edge(e1)
def test_graph_iterator_invalid_if_output_and_input_types_different():
g = Graph()
n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2")
n3 = PromptTestInvocation(id = "3", prompt = "Banana sushi")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge("1","collection","2","collection")
e2 = create_edge("2","item","3","prompt")
g.add_edge(e1)
with pytest.raises(InvalidEdgeError):
g.add_edge(e2)
def test_graph_validates():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","image")
g.add_edge(e1)
assert g.is_valid() == True
def test_graph_invalid_if_edges_reference_missing_nodes():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
g.nodes[n1.id] = n1
e1 = create_edge("1","image","2","image")
g.edges.append(e1)
assert g.is_valid() == False
def test_graph_invalid_if_subgraph_invalid():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1","image","2","image")
n1.graph.edges.append(e1)
g.nodes[n1.id] = n1
assert g.is_valid() == False
def test_graph_invalid_if_has_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
n2 = UpscaleInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","image")
e2 = create_edge("2","image","1","image")
g.edges.append(e1)
g.edges.append(e2)
assert g.is_valid() == False
def test_graph_invalid_with_invalid_connection():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","strength")
g.edges.append(e1)
assert g.is_valid() == False
# TODO: Subgraph operations
def test_graph_gets_subgraph_node():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
result = g.get_node('1.1')
assert result is not None
assert result.id == '1'
assert result == n1_1
def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
with pytest.raises(NodeNotFoundError):
result = g.get_node('1.2')
def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph()
n1 = GraphInvocation(id = "1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1)
g.add_node(n1)
n2 = UpscaleInvocation(id = "2")
g.add_node(n2)
with pytest.raises(NodeNotFoundError):
result = g.get_node('2.1')
def test_graph_gets_networkx_graph():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
nxg = g.nx_graph()
assert '1' in nxg.nodes
assert '2' in nxg.nodes
assert ('1','2') in nxg.edges
# TODO: Graph serializes and deserializes
def test_graph_can_serialize():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
# Not throwing on this line is sufficient
json = g.json()
def test_graph_can_deserialize():
g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
g.add_edge(e)
json = g.json()
g2 = Graph.parse_raw(json)
assert g2 is not None
assert g2.nodes['1'] is not None
assert g2.nodes['2'] is not None
assert len(g2.edges) == 1
assert g2.edges[0][0].node_id == '1'
assert g2.edges[0][0].field == 'image'
assert g2.edges[0][1].node_id == '2'
assert g2.edges[0][1].field == 'image'
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
schema = Graph.schema_json(indent = 2)

92
tests/nodes/test_nodes.py Normal file
View File

@ -0,0 +1,92 @@
from typing import Any, Callable, Literal
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.invocations.image import ImageField
from ldm.invoke.app.services.invocation_services import InvocationServices
from pydantic import Field
import pytest
# Define test invocations before importing anything that uses invocations
class ListPassThroughInvocationOutput(BaseInvocationOutput):
type: Literal['test_list_output'] = 'test_list_output'
collection: list[ImageField] = Field(default_factory=list)
class ListPassThroughInvocation(BaseInvocation):
type: Literal['test_list'] = 'test_list'
collection: list[ImageField] = Field(default_factory=list)
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection = self.collection)
class PromptTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_prompt_output'] = 'test_prompt_output'
prompt: str = Field(default = "")
class PromptTestInvocation(BaseInvocation):
type: Literal['test_prompt'] = 'test_prompt'
prompt: str = Field(default = "")
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt = self.prompt)
class ImageTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_image_output'] = 'test_image_output'
image: ImageField = Field()
class ImageTestInvocation(BaseInvocation):
type: Literal['test_image'] = 'test_image'
prompt: str = Field(default = "")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
collection: list[str] = Field(default_factory=list)
class PromptCollectionTestInvocation(BaseInvocation):
type: Literal['test_prompt_collection'] = 'test_prompt_collection'
collection: list[str] = Field()
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
from ldm.invoke.app.services.events import EventServiceBase
from ldm.invoke.app.services.graph import EdgeConnection
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
class TestEvent:
event_name: str
payload: Any
def __init__(self, event_name: str, payload: Any):
self.event_name = event_name
self.payload = payload
class TestEventService(EventServiceBase):
events: list
def __init__(self):
super().__init__()
self.events = list()
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None:
import time
start_time = time.time()
while time.time() - start_time < timeout:
if condition():
return
time.sleep(interval)
raise TimeoutError("Condition not met")

112
tests/nodes/test_sqlite.py Normal file
View File

@ -0,0 +1,112 @@
from ldm.invoke.app.services.sqlite import SqliteItemStorage, sqlite_memory
from pydantic import BaseModel, Field
class TestModel(BaseModel):
id: str = Field(description = "ID")
name: str = Field(description = "Name")
def test_sqlite_service_can_create_and_get():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
assert db.get('1') == TestModel(id = '1', name = 'Test')
def test_sqlite_service_can_list():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list()
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_delete():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.delete('1')
assert db.get('1') is None
def test_sqlite_service_calls_set_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
called = False
def on_changed(item: TestModel):
nonlocal called
called = True
db.on_changed(on_changed)
db.set(TestModel(id = '1', name = 'Test'))
assert called
def test_sqlite_service_calls_delete_callback():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
called = False
def on_deleted(item_id: str):
nonlocal called
called = True
db.on_deleted(on_deleted)
db.set(TestModel(id = '1', name = 'Test'))
db.delete('1')
assert called
def test_sqlite_service_can_list_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list(page = 0, per_page = 2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')]
def test_sqlite_service_can_list_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.list(page = 1, per_page = 2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_search():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test')
assert results.page == 0
assert results.pages == 1
assert results.per_page == 10
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')]
def test_sqlite_service_can_search_with_pagination():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test', page = 0, per_page = 2)
assert results.page == 0
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')]
def test_sqlite_service_can_search_with_pagination_and_offset():
db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id')
db.set(TestModel(id = '1', name = 'Test'))
db.set(TestModel(id = '2', name = 'Test'))
db.set(TestModel(id = '3', name = 'Test'))
results = db.search(query = 'Test', page = 1, per_page = 2)
assert results.page == 1
assert results.pages == 2
assert results.per_page == 2
assert results.total == 3
assert results.items == [TestModel(id = '3', name = 'Test')]