diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 347fba7e97..5698d25758 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -3,6 +3,8 @@ import os from argparse import Namespace +from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage + from ...backend import Globals from ..services.model_manager_initializer import get_model_manager from ..services.restoration_services import RestorationServices @@ -54,7 +56,9 @@ class ApiDependencies: os.path.join(os.path.dirname(__file__), "../../../../outputs") ) - images = DiskImageStorage(output_folder) + latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')) + + images = DiskImageStorage(f'{output_folder}/images') # TODO: build a file/path manager? db_location = os.path.join(output_folder, "invokeai.db") @@ -62,6 +66,7 @@ class ApiDependencies: services = InvocationServices( model_manager=get_model_manager(config), events=events, + latents=latents, images=images, queue=MemoryInvocationQueue(), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py index 21e65291e9..5f4da73303 100644 --- a/invokeai/app/cli/commands.py +++ b/invokeai/app/cli/commands.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod import argparse from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints from pydantic import BaseModel, Field - +import networkx as nx +import matplotlib.pyplot as plt from ..invocations.image import ImageField from ..services.graph import GraphExecutionState from ..services.invoker import Invoker @@ -46,7 +47,7 @@ def add_parsers( f"--{name}", dest=name, type=field_type, - default=field.default, + default=field.default if field.default_factory is None else field.default_factory(), choices=allowed_values, help=field.field_info.description, ) @@ -55,7 +56,7 @@ def add_parsers( f"--{name}", dest=name, type=field.type_, - default=field.default, + default=field.default if field.default_factory is None else field.default_factory(), help=field.field_info.description, ) @@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand): del context.defaults[self.field] else: context.defaults[self.field] = self.value + + +class DrawGraphCommand(BaseCommand): + """Debugs a graph""" + type: Literal['draw_graph'] = 'draw_graph' + + def run(self, context: CliContext) -> None: + session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) + nxgraph = session.graph.nx_graph_flat() + + # Draw the networkx graph + plt.figure(figsize=(20, 20)) + pos = nx.spectral_layout(nxgraph) + nx.draw_networkx_nodes(nxgraph, pos, node_size=1000) + nx.draw_networkx_edges(nxgraph, pos, width=2) + nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif") + plt.axis("off") + plt.show() + + +class DrawExecutionGraphCommand(BaseCommand): + """Debugs an execution graph""" + type: Literal['draw_xgraph'] = 'draw_xgraph' + + def run(self, context: CliContext) -> None: + session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) + nxgraph = session.execution_graph.nx_graph_flat() + + # Draw the networkx graph + plt.figure(figsize=(20, 20)) + pos = nx.spectral_layout(nxgraph) + nx.draw_networkx_nodes(nxgraph, pos, node_size=1000) + nx.draw_networkx_edges(nxgraph, pos, width=2) + nx.draw_networkx_labels(nxgraph, pos, font_size=20, 
font_family="sans-serif") + plt.axis("off") + plt.show() diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index bf003e5cb1..a257825dcc 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -2,6 +2,7 @@ import argparse import os +import re import shlex import time from typing import ( @@ -12,6 +13,8 @@ from typing import ( from pydantic import BaseModel from pydantic.fields import Field +from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage + from ..backend import Args from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history from .cli.completer import set_autocompleter @@ -20,7 +23,7 @@ from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase from .services.model_manager_initializer import get_model_manager from .services.restoration_services import RestorationServices -from .services.graph import Edge, EdgeConnection, GraphExecutionState +from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_services import InvocationServices @@ -44,7 +47,7 @@ def add_invocation_args(command_parser): "-l", action="append", nargs=3, - help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)", + help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)", ) command_parser.add_argument( @@ -94,6 +97,9 @@ def generate_matching_edges( invalid_fields = set(["type", "id"]) matching_fields = matching_fields.difference(invalid_fields) + # Validate types + matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])] + edges = [ Edge( source=EdgeConnection(node_id=a.id, field=field), @@ -149,7 +155,8 @@ def invoke_cli(): services = InvocationServices( model_manager=model_manager, events=events, - images=DiskImageStorage(output_folder), + latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), + images=DiskImageStorage(f'{output_folder}/images'), queue=MemoryInvocationQueue(), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" @@ -162,6 +169,8 @@ def invoke_cli(): session: GraphExecutionState = invoker.create_execution_state() parser = get_command_parser() + re_negid = re.compile('^-[0-9]+$') + # Uncomment to print out previous sessions at startup # print(services.session_manager.list()) @@ -227,7 +236,11 @@ def invoke_cli(): # Parse provided links if "link_node" in args and args["link_node"]: for link in args["link_node"]: - link_node = context.session.graph.get_node(link) + node_id = link + if re_negid.match(node_id): + node_id = str(current_id + int(node_id)) + + link_node = context.session.graph.get_node(node_id) matching_edges = generate_matching_edges( link_node, command.command ) @@ -237,10 +250,15 @@ def invoke_cli(): if "link" in args and args["link"]: for link in args["link"]: - edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]] + edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]] + + node_id = link[0] + if re_negid.match(node_id): + node_id = str(current_id + int(node_id)) + edges.append( Edge( - 
source=EdgeConnection(node_id=link[1], field=link[0]), + source=EdgeConnection(node_id=node_id, field=link[1]), destination=EdgeConnection( node_id=command.command.id, field=link[2] ) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py new file mode 100644 index 0000000000..c68b7449cc --- /dev/null +++ b/invokeai/app/invocations/collections.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal + +import cv2 as cv +import numpy as np +import numpy.random +from PIL import Image, ImageOps +from pydantic import Field + +from ..services.image_storage import ImageType +from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput +from .image import ImageField, ImageOutput + + +class IntCollectionOutput(BaseInvocationOutput): + """A collection of integers""" + + type: Literal["int_collection"] = "int_collection" + + # Outputs + collection: list[int] = Field(default=[], description="The int collection") + + +class RangeInvocation(BaseInvocation): + """Creates a range""" + + type: Literal["range"] = "range" + + # Inputs + start: int = Field(default=0, description="The start of the range") + stop: int = Field(default=10, description="The stop of the range") + step: int = Field(default=1, description="The step of the range") + + def invoke(self, context: InvocationContext) -> IntCollectionOutput: + return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step))) + + +class RandomRangeInvocation(BaseInvocation): + """Creates a collection of random numbers""" + + type: Literal["random_range"] = "random_range" + + # Inputs + low: int = Field(default=0, description="The inclusive low value") + high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value") + size: int = Field(default=1, description="The number of values to generate") + + def invoke(self, context: InvocationContext) -> IntCollectionOutput: + return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size))) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py new file mode 100644 index 0000000000..0481282ba9 --- /dev/null +++ b/invokeai/app/invocations/latent.py @@ -0,0 +1,321 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal, Optional +from pydantic import BaseModel, Field +from torch import Tensor +import torch + +from ...backend.model_management.model_manager import ModelManager +from ...backend.util.devices import CUDA_DEVICE, torch_dtype +from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings +from ...backend.image_util.seamless import configure_model_padding +from ...backend.prompting.conditioning import get_uc_and_c_and_ec +from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +import numpy as np +from accelerate.utils import set_seed +from ..services.image_storage import ImageType +from .baseinvocation import BaseInvocation, InvocationContext +from .image import ImageField, ImageOutput +from ...backend.generator import Generator +from ...backend.stable_diffusion import PipelineIntermediateState +from ...backend.util.util import image_to_dataURL +from diffusers.schedulers import SchedulerMixin as Scheduler +import diffusers +from diffusers import 
DiffusionPipeline + + +class LatentsField(BaseModel): + """A latents field used for passing latents between invocations""" + + latents_name: Optional[str] = Field(default=None, description="The name of the latents") + + +class LatentsOutput(BaseInvocationOutput): + """Base class for invocations that output latents""" + #fmt: off + type: Literal["latent_output"] = "latent_output" + latents: LatentsField = Field(default=None, description="The output latents") + #fmt: on + +class NoiseOutput(BaseInvocationOutput): + """Invocation noise output""" + #fmt: off + type: Literal["noise_output"] = "noise_output" + noise: LatentsField = Field(default=None, description="The output noise") + #fmt: on + + +# TODO: this seems like a hack +scheduler_map = dict( + ddim=diffusers.DDIMScheduler, + dpmpp_2=diffusers.DPMSolverMultistepScheduler, + k_dpm_2=diffusers.KDPM2DiscreteScheduler, + k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler, + k_dpmpp_2=diffusers.DPMSolverMultistepScheduler, + k_euler=diffusers.EulerDiscreteScheduler, + k_euler_a=diffusers.EulerAncestralDiscreteScheduler, + k_heun=diffusers.HeunDiscreteScheduler, + k_lms=diffusers.LMSDiscreteScheduler, + plms=diffusers.PNDMScheduler, +) + + +SAMPLER_NAME_VALUES = Literal[ + tuple(list(scheduler_map.keys())) +] + + +def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: + scheduler_class = scheduler_map.get(scheduler_name,'ddim') + scheduler = scheduler_class.from_config(model.scheduler.config) + # hack copied over from generate.py + if not hasattr(scheduler, 'uses_inpainting_model'): + scheduler.uses_inpainting_model = lambda: False + return scheduler + + +def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8): + # limit noise to only the diffusion image channels, not the mask channels + input_channels = min(latent_channels, 4) + use_device = "cpu" if (use_mps_noise or device.type == "mps") else device + generator = torch.Generator(device=use_device).manual_seed(seed) + x = torch.randn( + [ + 1, + input_channels, + height // downsampling_factor, + width // downsampling_factor, + ], + dtype=torch_dtype(device), + device=use_device, + generator=generator, + ).to(device) + # if self.perlin > 0.0: + # perlin_noise = self.get_perlin_noise( + # width // self.downsampling_factor, height // self.downsampling_factor + # ) + # x = (1 - self.perlin) * x + self.perlin * perlin_noise + return x + + +class NoiseInvocation(BaseInvocation): + """Generates latent noise.""" + + type: Literal["noise"] = "noise" + + # Inputs + seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", ) + width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", ) + height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", ) + + def invoke(self, context: InvocationContext) -> NoiseOutput: + device = torch.device(CUDA_DEVICE) + noise = get_noise(self.width, self.height, device, self.seed) + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, noise) + return NoiseOutput( + noise=LatentsField(latents_name=name) + ) + + +# Text to image +class TextToLatentsInvocation(BaseInvocation): + """Generates latents from a prompt.""" + + type: Literal["t2l"] = "t2l" + + # Inputs + # TODO: consider making prompt optional to enable providing prompt through a link + # fmt: off + prompt: Optional[str] = 
Field(description="The prompt to generate an image from") + seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", ) + noise: Optional[LatentsField] = Field(description="The noise to use") + steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") + width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", ) + height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", ) + cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) + sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" ) + seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) + seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") + model: str = Field(default="", description="The model to use (currently ignored)") + progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) + # fmt: on + + # TODO: pass this an emitter method or something? or a session for dispatching? + def dispatch_progress( + self, context: InvocationContext, sample: Tensor, step: int + ) -> None: + # TODO: only output a preview image when requested + image = Generator.sample_to_lowres_estimated_image(sample) + + (width, height) = image.size + width *= 8 + height *= 8 + + dataURL = image_to_dataURL(image, image_format="JPEG") + + context.services.events.emit_generator_progress( + context.graph_execution_state_id, + self.id, + { + "width": width, + "height": height, + "dataURL": dataURL + }, + step, + self.steps, + ) + + def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline: + model_info = model_manager.get_model(self.model) + model_name = model_info['model_name'] + model_hash = model_info['hash'] + model: StableDiffusionGeneratorPipeline = model_info['model'] + model.scheduler = get_scheduler( + model=model, + scheduler_name=self.sampler_name + ) + + if isinstance(model, DiffusionPipeline): + for component in [model.unet, model.vae]: + configure_model_padding(component, + self.seamless, + self.seamless_axes + ) + else: + configure_model_padding(model, + self.seamless, + self.seamless_axes + ) + + return model + + + def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData: + uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model) + conditioning_data = ConditioningData( + uc, + c, + self.cfg_scale, + extra_conditioning_info, + postprocessing_settings=PostprocessingSettings( + threshold=0.0,#threshold, + warmup=0.2,#warmup, + h_symmetry_time_pct=None,#h_symmetry_time_pct, + v_symmetry_time_pct=None#v_symmetry_time_pct, + ), + ).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta) + return conditioning_data + + + def invoke(self, context: InvocationContext) -> LatentsOutput: + noise = context.services.latents.get(self.noise.latents_name) + + def step_callback(state: PipelineIntermediateState): + self.dispatch_progress(context, state.latents, state.step) + + model = self.get_model(context.services.model_manager) + conditioning_data = self.get_conditioning_data(model) + + # TODO: Verify the noise is the right size + + result_latents, result_attention_map_saver = 
model.latents_from_embeddings( + latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)), + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + callback=step_callback + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + torch.cuda.empty_cache() + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, result_latents) + return LatentsOutput( + latents=LatentsField(latents_name=name) + ) + + +class LatentsToLatentsInvocation(TextToLatentsInvocation): + """Generates latents using latents as base image.""" + + type: Literal["l2l"] = "l2l" + + # Inputs + latents: Optional[LatentsField] = Field(description="The latents to use as a base image") + strength: float = Field(default=0.5, description="The strength of the latents to use") + + def invoke(self, context: InvocationContext) -> LatentsOutput: + noise = context.services.latents.get(self.noise.latents_name) + latent = context.services.latents.get(self.latents.latents_name) + + def step_callback(state: PipelineIntermediateState): + self.dispatch_progress(context, state.latents, state.step) + + model = self.get_model(context.services.model_manager) + conditioning_data = self.get_conditioning_data(model) + + # TODO: Verify the noise is the right size + + initial_latents = latent if self.strength < 1.0 else torch.zeros_like( + latent, device=model.device, dtype=latent.dtype + ) + + timesteps, _ = model.get_img2img_timesteps( + self.steps, + self.strength, + device=model.device, + ) + + result_latents, result_attention_map_saver = model.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + callback=step_callback + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + torch.cuda.empty_cache() + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, result_latents) + return LatentsOutput( + latents=LatentsField(latents_name=name) + ) + + +# Latent to image +class LatentsToImageInvocation(BaseInvocation): + """Generates an image from latents.""" + + type: Literal["l2i"] = "l2i" + + # Inputs + latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + model: str = Field(default="", description="The model to use") + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.services.latents.get(self.latents.latents_name) + + # TODO: this only really needs the vae + model_info = context.services.model_manager.get_model(self.model) + model: StableDiffusionGeneratorPipeline = model_info['model'] + + with torch.inference_mode(): + np_image = model.decode_latents(latents) + image = model.numpy_to_pil(np_image)[0] + + image_type = ImageType.RESULT + image_name = context.services.images.create_name( + context.graph_execution_state_id, self.id + ) + context.services.images.save(image_type, image_name, image) + return ImageOutput( + image=ImageField(image_type=image_type, image_name=image_name) + ) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py new file mode 100644 index 0000000000..ecdcc834c7 --- /dev/null +++ b/invokeai/app/invocations/math.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, timezone +from typing import Literal, Optional + +import numpy +from PIL import Image, 
ImageFilter, ImageOps +from pydantic import BaseModel, Field + +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext + + +class IntOutput(BaseInvocationOutput): + """An integer output""" + #fmt: off + type: Literal["int_output"] = "int_output" + a: int = Field(default=None, description="The output integer") + #fmt: on + + +class AddInvocation(BaseInvocation): + """Adds two numbers""" + #fmt: off + type: Literal["add"] = "add" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a + self.b) + + +class SubtractInvocation(BaseInvocation): + """Subtracts two numbers""" + #fmt: off + type: Literal["sub"] = "sub" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a - self.b) + + +class MultiplyInvocation(BaseInvocation): + """Multiplies two numbers""" + #fmt: off + type: Literal["mul"] = "mul" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a * self.b) + + +class DivideInvocation(BaseInvocation): + """Divides two numbers""" + #fmt: off + type: Literal["div"] = "div" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=int(self.a / self.b)) diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 171d86c9e3..98c2f29308 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -1069,9 +1069,8 @@ class GraphExecutionState(BaseModel): n for n in prepared_nodes if all( - pit + nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators - if nx.has_path(execution_graph, pit[0], n) ) ), None, diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 7f24c34378..2cd0f55fd9 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -2,6 +2,7 @@ from invokeai.backend import ModelManager from .events import EventServiceBase +from .latent_storage import LatentsStorageBase from .image_storage import ImageStorageBase from .restoration_services import RestorationServices from .invocation_queue import InvocationQueueABC @@ -11,6 +12,7 @@ class InvocationServices: """Services that can be used by invocations""" events: EventServiceBase + latents: LatentsStorageBase images: ImageStorageBase queue: InvocationQueueABC model_manager: ModelManager @@ -24,6 +26,7 @@ class InvocationServices: self, model_manager: ModelManager, events: EventServiceBase, + latents: LatentsStorageBase, images: ImageStorageBase, queue: InvocationQueueABC, graph_execution_manager: ItemStorageABC["GraphExecutionState"], @@ -32,6 +35,7 @@ class InvocationServices: ): self.model_manager = model_manager self.events = events + self.latents = latents self.images = images self.queue = queue self.graph_execution_manager = graph_execution_manager diff --git 
a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index 594477ed0f..e3fa6da851 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -33,7 +33,6 @@ class Invoker: self.services.graph_execution_manager.set(graph_execution_state) # Queue the invocation - print(f"queueing item {invocation.id}") self.services.queue.put( InvocationQueueItem( # session_id = session.id, diff --git a/invokeai/app/services/latent_storage.py b/invokeai/app/services/latent_storage.py new file mode 100644 index 0000000000..0184692e05 --- /dev/null +++ b/invokeai/app/services/latent_storage.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +import os +from abc import ABC, abstractmethod +from pathlib import Path +from queue import Queue +from typing import Dict + +import torch + +class LatentsStorageBase(ABC): + """Responsible for storing and retrieving latents.""" + + @abstractmethod + def get(self, name: str) -> torch.Tensor: + pass + + @abstractmethod + def set(self, name: str, data: torch.Tensor) -> None: + pass + + @abstractmethod + def delete(self, name: str) -> None: + pass + + +class ForwardCacheLatentsStorage(LatentsStorageBase): + """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" + + __cache: Dict[str, torch.Tensor] + __cache_ids: Queue + __max_cache_size: int + __underlying_storage: LatentsStorageBase + + def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): + self.__underlying_storage = underlying_storage + self.__cache = dict() + self.__cache_ids = Queue() + self.__max_cache_size = max_cache_size + + def get(self, name: str) -> torch.Tensor: + cache_item = self.__get_cache(name) + if cache_item is not None: + return cache_item + + latent = self.__underlying_storage.get(name) + self.__set_cache(name, latent) + return latent + + def set(self, name: str, data: torch.Tensor) -> None: + self.__underlying_storage.set(name, data) + self.__set_cache(name, data) + + def delete(self, name: str) -> None: + self.__underlying_storage.delete(name) + if name in self.__cache: + del self.__cache[name] + + def __get_cache(self, name: str) -> torch.Tensor|None: + return None if name not in self.__cache else self.__cache[name] + + def __set_cache(self, name: str, data: torch.Tensor): + if not name in self.__cache: + self.__cache[name] = data + self.__cache_ids.put(name) + if self.__cache_ids.qsize() > self.__max_cache_size: + self.__cache.pop(self.__cache_ids.get()) + + +class DiskLatentsStorage(LatentsStorageBase): + """Stores latents in a folder on disk without caching""" + + __output_folder: str + + def __init__(self, output_folder: str): + self.__output_folder = output_folder + Path(output_folder).mkdir(parents=True, exist_ok=True) + + def get(self, name: str) -> torch.Tensor: + latent_path = self.get_path(name) + return torch.load(latent_path) + + def set(self, name: str, data: torch.Tensor) -> None: + latent_path = self.get_path(name) + torch.save(data, latent_path) + + def delete(self, name: str) -> None: + latent_path = self.get_path(name) + os.remove(latent_path) + + def get_path(self, name: str) -> str: + return os.path.join(self.__output_folder, name) + \ No newline at end of file diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index b722539935..506b8653f8 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,6 +1,8 @@ from .test_invoker 
import create_edge from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from invokeai.app.invocations.collections import RangeInvocation +from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invocation_queue import MemoryInvocationQueue @@ -21,13 +23,14 @@ def simple_graph(): def mock_services(): # NOTE: none of these are actually called by the test invocations return InvocationServices( - model_manager = None, - events = None, - images = None, + model_manager = None, # type: ignore + events = None, # type: ignore + images = None, # type: ignore + latents = None, # type: ignore queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), - restoration = None, + restoration = None, # type: ignore ) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: @@ -73,31 +76,23 @@ def test_graph_is_not_complete(simple_graph, mock_services): def test_graph_state_expands_iterator(mock_services): graph = Graph() - test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) - graph.add_node(IterateInvocation(id = "2")) - graph.add_node(ImageTestInvocation(id = "3")) - graph.add_edge(create_edge("1", "collection", "2", "collection")) - graph.add_edge(create_edge("2", "item", "3", "prompt")) + graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1)) + graph.add_node(IterateInvocation(id = "1")) + graph.add_node(MultiplyInvocation(id = "2", b = 10)) + graph.add_node(AddInvocation(id = "3", b = 1)) + graph.add_edge(create_edge("0", "collection", "1", "collection")) + graph.add_edge(create_edge("1", "item", "2", "a")) + graph.add_edge(create_edge("2", "a", "3", "a")) g = GraphExecutionState(graph = graph) - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) - n3 = invoke_next(g, mock_services) - n4 = invoke_next(g, mock_services) - n5 = invoke_next(g, mock_services) + while not g.is_complete(): + invoke_next(g, mock_services) + + prepared_add_nodes = g.source_prepared_mapping['3'] + results = set([g.results[n].a for n in prepared_add_nodes]) + expected = set([1, 11, 21]) + assert results == expected - assert g.prepared_source_mapping[n1[0].id] == "1" - assert g.prepared_source_mapping[n2[0].id] == "2" - assert g.prepared_source_mapping[n3[0].id] == "2" - assert g.prepared_source_mapping[n4[0].id] == "3" - assert g.prepared_source_mapping[n5[0].id] == "3" - - assert isinstance(n4[0], ImageTestInvocation) - assert isinstance(n5[0], ImageTestInvocation) - - prompts = [n4[0].prompt, n5[0].prompt] - assert sorted(prompts) == sorted(test_prompts) def test_graph_state_collects(mock_services): graph = Graph() diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 718baa7a1f..68df708bdd 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -24,10 +24,11 @@ def mock_services() -> InvocationServices: model_manager = None, # type: ignore events = TestEventService(), images = None, # type: ignore 
+ latents = None, # type: ignore queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), - restoration = None, + restoration = None, # type: ignore ) @pytest.fixture()
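
For reference, a minimal usage sketch of the latents storage service added above. The output folder, key, and tensor shape are illustrative; ForwardCacheLatentsStorage keeps a bounded write-through cache in front of the disk store, so a get() immediately after set() is served from memory rather than from disk.

    import torch

    from invokeai.app.services.latent_storage import (
        DiskLatentsStorage,
        ForwardCacheLatentsStorage,
    )

    # The disk store persists each tensor as a file named after its latents key;
    # the forward-cache wrapper keeps the most recently used entries in memory.
    latents = ForwardCacheLatentsStorage(
        DiskLatentsStorage("outputs/latents"),  # illustrative path
        max_cache_size=20,                      # the default, shown explicitly
    )

    noise = torch.randn(1, 4, 64, 64)                # stand-in for real latents
    latents.set("session_id__node_id", noise)        # writes to disk and caches
    restored = latents.get("session_id__node_id")    # served from the in-memory cache
    latents.delete("session_id__node_id")            # removes the file and evicts the cache entry

The key format mirrors the one the new invocations use, f'{graph_execution_state_id}__{node_id}', so each node's output latents can be located by any downstream node in the same session.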
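
The updated iterator test composes the new collection and math invocations into a small graph; below is a standalone sketch of that wiring using the Graph/Edge API. The connect() helper is a stand-in for the tests' create_edge(), and IterateInvocation is assumed to be importable from the graph service module, as the tests use it.

    from invokeai.app.invocations.collections import RangeInvocation
    from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
    from invokeai.app.services.graph import Edge, EdgeConnection, Graph, IterateInvocation

    def connect(graph: Graph, src_id: str, src_field: str, dst_id: str, dst_field: str) -> None:
        # Stand-in for the tests' create_edge() helper.
        graph.add_edge(Edge(
            source=EdgeConnection(node_id=src_id, field=src_field),
            destination=EdgeConnection(node_id=dst_id, field=dst_field),
        ))

    graph = Graph()
    graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))  # emits [0, 1, 2]
    graph.add_node(IterateInvocation(id="1"))                         # fans the collection out
    graph.add_node(MultiplyInvocation(id="2", b=10))                  # item * 10
    graph.add_node(AddInvocation(id="3", b=1))                        # + 1, so results are {1, 11, 21}
    connect(graph, "0", "collection", "1", "collection")
    connect(graph, "1", "item", "2", "a")
    connect(graph, "2", "a", "3", "a")

During execution the iterate node is prepared once per collection item, so the multiply and add nodes are each expanded into three prepared instances, matching the set the test asserts against.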
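
Worth noting for CLI users: the --link flag's argument order is now 'source_node source_field dest_field', and the source node may be given relative to history. For example, --link -1 image image would connect the previous command's image output to the new command's image input (the field names here are only an illustration; any matching output/input pair works).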