InvokeAI/tests/nodes/test_nodes.py
Kyle Schouviller 34e3aa1f88 parent 9eed1919c2
author Kyle Schouviller <kyle0654@hotmail.com> 1669872800 -0800
committer Kyle Schouviller <kyle0654@hotmail.com> 1676240900 -0800

Adding base node architecture

Fix type annotation errors

Runs and generates, but breaks in saving session

Fix default model value setting. Fix deprecation warning.

Fixed node API

Adding markdown docs

Simplifying Generate construction in apps

[nodes] A few minor changes (#2510)

* Pin api-related requirements

* Remove confusing extra CORS origins list

* Adds response models for HTTP 200

[nodes] Adding graph_execution_state to soon replace session. Adding tests with pytest.

Minor typing fixes

[nodes] Fix some small output query hookups

[nodes] Fix some additional typing issues

[nodes] Move and expand graph code. Add base item storage and sqlite implementation.

Update startup to match new code

[nodes] Add callbacks to item storage

[nodes] Adding an InvocationContext object to use for invocations to provide easier extensibility

[nodes] New execution model that handles iteration

[nodes] Fixing the CLI

[nodes] Adding a note to the CLI

[nodes] Split processing thread into separate service

[nodes] Add error message on node processing failure

Removing old files and duplicated packages

Adding python-multipart
2023-02-24 18:57:02 -08:00

92 lines
3.3 KiB
Python

from typing import Any, Callable, Literal
from ldm.invoke.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from ldm.invoke.app.invocations.image import ImageField
from ldm.invoke.app.services.invocation_services import InvocationServices
from pydantic import Field
import pytest
# Define test invocations before importing anything that uses invocations
class ListPassThroughInvocationOutput(BaseInvocationOutput):
    # Test-only output holding a list of image references; returned by
    # ListPassThroughInvocation.invoke.
    # Literal tag naming this output type (presumably used to tell output
    # models apart when (de)serializing — confirm in BaseInvocationOutput).
    type: Literal['test_list_output'] = 'test_list_output'
    # default_factory gives every instance its own fresh empty list.
    collection: list[ImageField] = Field(default_factory=list)
class ListPassThroughInvocation(BaseInvocation):
    # Test node that forwards its input list of images to its output unchanged.
    type: Literal['test_list'] = 'test_list'
    collection: list[ImageField] = Field(default_factory=list)

    def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
        """Emit this node's ``collection`` as-is; ``context`` is not consulted."""
        passthrough = self.collection
        return ListPassThroughInvocationOutput(collection=passthrough)
class PromptTestInvocationOutput(BaseInvocationOutput):
    # Test-only output carrying a single prompt string; returned by
    # PromptTestInvocation.invoke.
    # Literal tag naming this output type (presumably a discriminator — confirm
    # in BaseInvocationOutput).
    type: Literal['test_prompt_output'] = 'test_prompt_output'
    # Empty string when no prompt is supplied.
    prompt: str = Field(default = "")
class PromptTestInvocation(BaseInvocation):
    # Test node that echoes its ``prompt`` field to its output.
    type: Literal['test_prompt'] = 'test_prompt'
    prompt: str = Field(default="")

    def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
        """Copy this node's prompt into a PromptTestInvocationOutput; ``context`` is unused."""
        text = self.prompt
        return PromptTestInvocationOutput(prompt=text)
class ImageTestInvocationOutput(BaseInvocationOutput):
    # Test-only output carrying one image reference; returned by
    # ImageTestInvocation.invoke.
    # Literal tag naming this output type (presumably a discriminator — confirm
    # in BaseInvocationOutput).
    type: Literal['test_image_output'] = 'test_image_output'
    # Field() with no default: the image must be supplied at construction.
    image: ImageField = Field()
class ImageTestInvocation(BaseInvocation):
    # Test node that "produces" an image without doing any real work: the
    # emitted ImageField is named after this node's id.
    type: Literal['test_image'] = 'test_image'
    # Accepted so graphs can wire a prompt in; invoke() never reads it.
    prompt: str = Field(default="")

    def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
        """Return an ImageField whose ``image_name`` is this node's ``id``."""
        image_ref = ImageField(image_name=self.id)
        return ImageTestInvocationOutput(image=image_ref)
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
    # Test-only output holding a list of prompt strings; returned by
    # PromptCollectionTestInvocation.invoke.
    # Literal tag naming this output type (presumably a discriminator — confirm
    # in BaseInvocationOutput).
    type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
    # default_factory gives every instance its own fresh empty list.
    collection: list[str] = Field(default_factory=list)
class PromptCollectionTestInvocation(BaseInvocation):
    # Test node that emits a copy of its list of prompts.
    type: Literal['test_prompt_collection'] = 'test_prompt_collection'
    # Field() with no default: callers must provide the collection.
    collection: list[str] = Field()

    def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
        """Return a shallow copy of ``collection`` so the output does not alias the node's list."""
        prompts = list(self.collection)
        return PromptCollectionTestInvocationOutput(collection=prompts)
from ldm.invoke.app.services.events import EventServiceBase
from ldm.invoke.app.services.graph import EdgeConnection
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
    """Build the endpoint pair for an edge between two graph nodes.

    Args:
        from_id: id of the node the edge starts at.
        from_field: output field name on the starting node.
        to_id: id of the node the edge ends at.
        to_field: input field name on the ending node.

    Returns:
        A ``(source, destination)`` tuple of EdgeConnection objects.
    """
    source = EdgeConnection(node_id=from_id, field=from_field)
    destination = EdgeConnection(node_id=to_id, field=to_field)
    return source, destination
class TestEvent:
    """Simple holder for an event name and its payload."""

    # The ``Test`` prefix would make pytest try to collect this class as a
    # test case (and warn, because it defines ``__init__``); opt out explicitly.
    __test__ = False

    event_name: str
    payload: Any

    def __init__(self, event_name: str, payload: Any):
        self.event_name = event_name
        self.payload = payload

    def __repr__(self) -> str:
        return f"{type(self).__name__}(event_name={self.event_name!r}, payload={self.payload!r})"
class TestEventService(EventServiceBase):
    """Event service for tests: records dispatched events rather than forwarding them anywhere."""

    # The ``Test`` prefix would make pytest try to collect this class as a
    # test case (and warn, because it defines ``__init__``); opt out explicitly.
    __test__ = False

    # Every event passed to dispatch(), in call order.
    events: list[TestEvent]

    def __init__(self):
        super().__init__()
        self.events = []

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Record the event so tests can inspect what was emitted."""
        self.events.append(TestEvent(event_name, payload))
def wait_until(condition: Callable[[], bool], timeout: float = 10, interval: float = 0.1) -> None:
    """Poll ``condition`` until it returns a truthy value or ``timeout`` seconds elapse.

    The condition is always evaluated at least once (even with ``timeout <= 0``)
    and is re-checked after the final sleep, so a condition that becomes true
    right at the deadline is still observed.

    Args:
        condition: zero-argument callable polled for truthiness.
        timeout: maximum number of seconds to wait.
        interval: seconds to sleep between polls (capped at the remaining time).

    Raises:
        TimeoutError: if ``condition`` is still falsy when the deadline passes.
    """
    import time

    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while True:
        if condition():
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(interval, remaining))
    raise TimeoutError("Condition not met")