Merge branch 'main' into feat/return-submodels

This commit is contained in:
Lincoln Stein 2023-04-06 22:03:31 -04:00 committed by GitHub
commit f022c89249
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 964 additions and 46 deletions

View File

@ -18,6 +18,7 @@ on:
permissions: permissions:
contents: write contents: write
packages: write
jobs: jobs:
docker: docker:

View File

@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
masking option: masking option:
```bash ```bash
invoke> a fluffy cat eating a hotdot invoke> a fluffy cat eating a hotdog
Outputs: Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog [1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat

View File

@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str):
optional_modules = None optional_modules = None
if OS == "Linux": if OS == "Linux":
if device == "rocm": if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.2" url = "https://download.pytorch.org/whl/rocm5.4.2"
elif device == "cpu": elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu" url = "https://download.pytorch.org/whl/cpu"

View File

@ -3,6 +3,8 @@
import os import os
from argparse import Namespace from argparse import Namespace
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices from ..services.restoration_services import RestorationServices
@ -54,7 +56,9 @@ class ApiDependencies:
os.path.join(os.path.dirname(__file__), "../../../../outputs") os.path.join(os.path.dirname(__file__), "../../../../outputs")
) )
images = DiskImageStorage(output_folder) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
images = DiskImageStorage(f'{output_folder}/images')
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db") db_location = os.path.join(output_folder, "invokeai.db")
@ -62,6 +66,7 @@ class ApiDependencies:
services = InvocationServices( services = InvocationServices(
model_manager=get_model_manager(config), model_manager=get_model_manager(config),
events=events, events=events,
latents=latents,
images=images, images=images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_execution_manager=SqliteItemStorage[GraphExecutionState]( graph_execution_manager=SqliteItemStorage[GraphExecutionState](

View File

@ -23,6 +23,16 @@ async def get_image(
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name) filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename) return FileResponse(filename)
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
async def get_thumbnail(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
"""Gets a thumbnail"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
return FileResponse(filename)
@images_router.post( @images_router.post(
"/uploads/", "/uploads/",

View File

@ -0,0 +1,279 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)
# @socketio.on("searchForModels")
# def handle_search_models(search_folder: str):
# try:
# if not search_folder:
# socketio.emit(
# "foundModels",
# {"search_folder": None, "found_models": None},
# )
# else:
# (
# search_folder,
# found_models,
# ) = self.generate.model_manager.search_models(search_folder)
# socketio.emit(
# "foundModels",
# {"search_folder": search_folder, "found_models": found_models},
# )
# except Exception as e:
# self.handle_exceptions(e)
# print("\n")
# @socketio.on("addNewModel")
# def handle_add_model(new_model_config: dict):
# try:
# model_name = new_model_config["name"]
# del new_model_config["name"]
# model_attributes = new_model_config
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True
# print(f">> Adding New Model: {model_name}")
# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "newModelAdded",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": update,
# },
# )
# print(f">> New Model Added: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("deleteModel")
# def handle_delete_model(model_name: str):
# try:
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name)
# self.generate.model_manager.commit(opt.conf)
# updated_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelDeleted",
# {
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# self.handle_exceptions(e)

View File

@ -51,7 +51,7 @@ async def list_sessions(
query: str = Query(default="", description="The query string to search for"), query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]: ) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching""" """Gets a list of sessions, optionally searching"""
if filter == "": if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list( result = ApiDependencies.invoker.services.graph_execution_manager.list(
page, per_page page, per_page
) )

View File

@ -14,7 +14,7 @@ from pydantic.schema import schema
from ..backend import Args from ..backend import Args
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import images, sessions from .api.routers import images, sessions, models
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations import * from .invocations import *
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api") app.include_router(images.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
# Build a custom OpenAPI to include all outputs # Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow? # TODO: can outputs be included on metadata of invocation schemas somehow?

View File

@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
import argparse import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
from ..invocations.image import ImageField from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker from ..services.invoker import Invoker
@ -46,7 +47,7 @@ def add_parsers(
f"--{name}", f"--{name}",
dest=name, dest=name,
type=field_type, type=field_type,
default=field.default, default=field.default if field.default_factory is None else field.default_factory(),
choices=allowed_values, choices=allowed_values,
help=field.field_info.description, help=field.field_info.description,
) )
@ -55,7 +56,7 @@ def add_parsers(
f"--{name}", f"--{name}",
dest=name, dest=name,
type=field.type_, type=field.type_,
default=field.default, default=field.default if field.default_factory is None else field.default_factory(),
help=field.field_info.description, help=field.field_info.description,
) )
@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand):
del context.defaults[self.field] del context.defaults[self.field]
else: else:
context.defaults[self.field] = self.value context.defaults[self.field] = self.value
class DrawGraphCommand(BaseCommand):
"""Debugs a graph"""
type: Literal['draw_graph'] = 'draw_graph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
class DrawExecutionGraphCommand(BaseCommand):
"""Debugs an execution graph"""
type: Literal['draw_xgraph'] = 'draw_xgraph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.execution_graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()

View File

@ -2,6 +2,7 @@
import argparse import argparse
import os import os
import re
import shlex import shlex
import time import time
from typing import ( from typing import (
@ -12,6 +13,8 @@ from typing import (
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import Field from pydantic.fields import Field
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter from .cli.completer import set_autocompleter
@ -20,7 +23,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
from .services.image_storage import DiskImageStorage from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
@ -44,7 +47,7 @@ def add_invocation_args(command_parser):
"-l", "-l",
action="append", action="append",
nargs=3, nargs=3,
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)", help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
) )
command_parser.add_argument( command_parser.add_argument(
@ -94,6 +97,9 @@ def generate_matching_edges(
invalid_fields = set(["type", "id"]) invalid_fields = set(["type", "id"])
matching_fields = matching_fields.difference(invalid_fields) matching_fields = matching_fields.difference(invalid_fields)
# Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
edges = [ edges = [
Edge( Edge(
source=EdgeConnection(node_id=a.id, field=field), source=EdgeConnection(node_id=a.id, field=field),
@ -149,7 +155,8 @@ def invoke_cli():
services = InvocationServices( services = InvocationServices(
model_manager=model_manager, model_manager=model_manager,
events=events, events=events,
images=DiskImageStorage(output_folder), latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images'),
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_execution_manager=SqliteItemStorage[GraphExecutionState]( graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"
@ -162,6 +169,8 @@ def invoke_cli():
session: GraphExecutionState = invoker.create_execution_state() session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser() parser = get_command_parser()
re_negid = re.compile('^-[0-9]+$')
# Uncomment to print out previous sessions at startup # Uncomment to print out previous sessions at startup
# print(services.session_manager.list()) # print(services.session_manager.list())
@ -227,7 +236,11 @@ def invoke_cli():
# Parse provided links # Parse provided links
if "link_node" in args and args["link_node"]: if "link_node" in args and args["link_node"]:
for link in args["link_node"]: for link in args["link_node"]:
link_node = context.session.graph.get_node(link) node_id = link
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
link_node, command.command link_node, command.command
) )
@ -237,10 +250,15 @@ def invoke_cli():
if "link" in args and args["link"]: if "link" in args and args["link"]:
for link in args["link"]: for link in args["link"]:
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]] edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
node_id = link[0]
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
edges.append( edges.append(
Edge( Edge(
source=EdgeConnection(node_id=link[1], field=link[0]), source=EdgeConnection(node_id=node_id, field=link[1]),
destination=EdgeConnection( destination=EdgeConnection(
node_id=command.command.id, field=link[2] node_id=command.command.id, field=link[2]
) )

View File

@ -0,0 +1,50 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection"] = "int_collection"
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class RangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["range"] = "range"
# Inputs
start: int = Field(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = Field(default=1, description="The number of values to generate")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))

View File

@ -0,0 +1,321 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
from pydantic import BaseModel, Field
from torch import Tensor
import torch
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
import numpy as np
from accelerate.utils import set_seed
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
type: Literal["latent_output"] = "latent_output"
latents: LatentsField = Field(default=None, description="The output latents")
#fmt: on
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
noise: LatentsField = Field(default=None, description="The output noise")
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(CUDA_DEVICE)
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
return NoiseOutput(
noise=LatentsField(latents_name=name)
)
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from a prompt."""
type: Literal["t2l"] = "t2l"
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
self.steps,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = model_manager.get_model(self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.sampler_name
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
return model
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
conditioning_data = ConditioningData(
uc,
c,
self.cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
return conditioning_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
# Latent to image
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = context.services.model_manager.get_model(self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@ -0,0 +1,68 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class IntOutput(BaseInvocationOutput):
"""An integer output"""
#fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
#fmt: on
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
#fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
#fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
#fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
#fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))

View File

@ -1069,9 +1069,8 @@ class GraphExecutionState(BaseModel):
n n
for n in prepared_nodes for n in prepared_nodes
if all( if all(
pit nx.has_path(execution_graph, pit[0], n)
for pit in parent_iterators for pit in parent_iterators
if nx.has_path(execution_graph, pit[0], n)
) )
), ),
None, None,

View File

@ -9,6 +9,7 @@ from queue import Queue
from typing import Dict from typing import Dict
from PIL.Image import Image from PIL.Image import Image
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter from invokeai.backend.image_util import PngWriter
@ -66,6 +67,9 @@ class DiskImageStorage(ImageStorageBase):
Path(os.path.join(output_folder, image_type)).mkdir( Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_type: ImageType, image_name: str) -> Image: def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
@ -87,7 +91,11 @@ class DiskImageStorage(ImageStorageBase):
self.__pngWriter.save_image_and_prompt_to_png( self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None image, "", image_subpath, None
) # TODO: just pass full path to png writer ) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
filename=image_name,
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
)
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image) self.__set_cache(image_path, image)

View File

@ -2,6 +2,7 @@
from invokeai.backend import ModelManager from invokeai.backend import ModelManager
from .events import EventServiceBase from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC from .invocation_queue import InvocationQueueABC
@ -11,6 +12,7 @@ class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
events: EventServiceBase events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase images: ImageStorageBase
queue: InvocationQueueABC queue: InvocationQueueABC
model_manager: ModelManager model_manager: ModelManager
@ -24,6 +26,7 @@ class InvocationServices:
self, self,
model_manager: ModelManager, model_manager: ModelManager,
events: EventServiceBase, events: EventServiceBase,
latents: LatentsStorageBase,
images: ImageStorageBase, images: ImageStorageBase,
queue: InvocationQueueABC, queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
@ -32,6 +35,7 @@ class InvocationServices:
): ):
self.model_manager = model_manager self.model_manager = model_manager
self.events = events self.events = events
self.latents = latents
self.images = images self.images = images
self.queue = queue self.queue = queue
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager

View File

@ -33,7 +33,6 @@ class Invoker:
self.services.graph_execution_manager.set(graph_execution_state) self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation # Queue the invocation
print(f"queueing item {invocation.id}")
self.services.queue.put( self.services.queue.put(
InvocationQueueItem( InvocationQueueItem(
# session_id = session.id, # session_id = session.id,

View File

@ -0,0 +1,93 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict
import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
def delete(self, name: str) -> None:
pass
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:
return cache_item
latent = self.__underlying_storage.get(name)
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
if not name in self.__cache:
self.__cache[name] = data
self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)

View File

@ -59,6 +59,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),), (item.json(),),
) )
self._conn.commit()
finally: finally:
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
@ -84,6 +85,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute( self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
) )
self._conn.commit()
finally: finally:
self._lock.release() self._lock.release()
self._on_deleted(id) self._on_deleted(id)

View File

@ -0,0 +1,25 @@
import os
from PIL import Image
def save_thumbnail(
image: Image.Image,
filename: str,
path: str,
size: int = 256,
) -> str:
"""
Saves a thumbnail of an image, returning its path.
"""
base_filename = os.path.splitext(filename)[0]
thumbnail_path = os.path.join(path, base_filename + ".webp")
if os.path.exists(thumbnail_path):
return thumbnail_path
image_copy = image.copy()
image_copy.thumbnail(size=(size, size))
image_copy.save(thumbnail_path, "WEBP")
return thumbnail_path

View File

@ -581,6 +581,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!replay"): elif command.startswith("!replay"):
file_path = command.replace("!replay", "", 1).strip() file_path = command.replace("!replay", "", 1).strip()
file_path = os.path.join(opt.outdir, file_path)
if infile is None and os.path.isfile(file_path): if infile is None and os.path.isfile(file_path):
infile = open(file_path, "r", encoding="utf-8") infile = open(file_path, "r", encoding="utf-8")
completer.add_history(command) completer.add_history(command)

View File

@ -38,7 +38,7 @@ dependencies = [
"albumentations", "albumentations",
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==1.0.4", "compel==1.0.5",
"datasets", "datasets",
"diffusers[torch]~=0.14", "diffusers[torch]~=0.14",
"dnspython==2.2.1", "dnspython==2.2.1",

View File

@ -1,6 +1,8 @@
from .test_invoker import create_edge from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@ -21,13 +23,14 @@ def simple_graph():
def mock_services(): def mock_services():
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
return InvocationServices( return InvocationServices(
model_manager = None, model_manager = None, # type: ignore
events = None, events = None, # type: ignore
images = None, images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(), processor = DefaultInvocationProcessor(),
restoration = None, restoration = None, # type: ignore
) )
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
@ -73,31 +76,23 @@ def test_graph_is_not_complete(simple_graph, mock_services):
def test_graph_state_expands_iterator(mock_services): def test_graph_state_expands_iterator(mock_services):
graph = Graph() graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"] graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) graph.add_node(IterateInvocation(id = "1"))
graph.add_node(IterateInvocation(id = "2")) graph.add_node(MultiplyInvocation(id = "2", b = 10))
graph.add_node(ImageTestInvocation(id = "3")) graph.add_node(AddInvocation(id = "3", b = 1))
graph.add_edge(create_edge("1", "collection", "2", "collection")) graph.add_edge(create_edge("0", "collection", "1", "collection"))
graph.add_edge(create_edge("2", "item", "3", "prompt")) graph.add_edge(create_edge("1", "item", "2", "a"))
graph.add_edge(create_edge("2", "a", "3", "a"))
g = GraphExecutionState(graph = graph) g = GraphExecutionState(graph = graph)
n1 = invoke_next(g, mock_services) while not g.is_complete():
n2 = invoke_next(g, mock_services) invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services) prepared_add_nodes = g.source_prepared_mapping['3']
n5 = invoke_next(g, mock_services) results = set([g.results[n].a for n in prepared_add_nodes])
expected = set([1, 11, 21])
assert results == expected
assert g.prepared_source_mapping[n1[0].id] == "1"
assert g.prepared_source_mapping[n2[0].id] == "2"
assert g.prepared_source_mapping[n3[0].id] == "2"
assert g.prepared_source_mapping[n4[0].id] == "3"
assert g.prepared_source_mapping[n5[0].id] == "3"
assert isinstance(n4[0], ImageTestInvocation)
assert isinstance(n5[0], ImageTestInvocation)
prompts = [n4[0].prompt, n5[0].prompt]
assert sorted(prompts) == sorted(test_prompts)
def test_graph_state_collects(mock_services): def test_graph_state_collects(mock_services):
graph = Graph() graph = Graph()

View File

@ -24,10 +24,11 @@ def mock_services() -> InvocationServices:
model_manager = None, # type: ignore model_manager = None, # type: ignore
events = TestEventService(), events = TestEventService(),
images = None, # type: ignore images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(), processor = DefaultInvocationProcessor(),
restoration = None, restoration = None, # type: ignore
) )
@pytest.fixture() @pytest.fixture()