From de189f2db6c4d5dbf893b93300b10fe1e3a77d5a Mon Sep 17 00:00:00 2001
From: AbdBarho
Date: Sun, 9 Apr 2023 21:53:59 +0200
Subject: [PATCH 1/9] Increase chunk size when computing SHAs

---
 invokeai/backend/model_management/model_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index a51a2fec22..534b526081 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -1204,7 +1204,7 @@ class ModelManager(object):
         return self.device.type == "cuda"

     def _diffuser_sha256(
-        self, name_or_path: Union[str, Path], chunksize=4096
+        self, name_or_path: Union[str, Path], chunksize=16777216
     ) -> Union[str, bytes]:
         path = None
         if isinstance(name_or_path, Path):

From 5bd0bb637f11fb4acc8ac504300a85f347a1df8a Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Mon, 10 Apr 2023 18:14:06 +1000
Subject: [PATCH 2/9] fix(nodes): add missing type to `ImageField`

---
 invokeai/app/models/image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py
index 1561e6bcc5..9edb16800d 100644
--- a/invokeai/app/models/image.py
+++ b/invokeai/app/models/image.py
@@ -12,7 +12,7 @@ class ImageType(str, Enum):

 class ImageField(BaseModel):
     """An image field used for passing image objects between invocations"""

-    image_type: str = Field(
+    image_type: ImageType = Field(
         default=ImageType.RESULT, description="The type of the image"
     )
     image_name: Optional[str] = Field(default=None, description="The name of the image")

From dad3a7f263ce1774a3454d305f7e62e22a909fb3 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Mon, 10 Apr 2023 18:13:23 +1000
Subject: [PATCH 3/9] fix(nodes): `sampler_name` --> `scheduler`

The name of this field was changed at some point, but the nodes still used
the old name, so scheduler selection did nothing. Simple fix.
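For reference, the set of allowed values for this field is a `Literal` built at
runtime from whatever schedulers the generator reports, and pydantic silently
ignores unrecognized input fields by default, so a renamed field fails quietly:
the incoming value presumably arrived under the new name, never matched the old
field, and the default was used every time. A minimal sketch of the pattern
(the scheduler names and the `schedulers()` helper here are assumed for
illustration, not the real list):

```python
from typing import Literal

from pydantic import BaseModel, Field


def schedulers() -> list[str]:
    # Hypothetical stand-in for InvokeAIGenerator.schedulers()
    return ["k_lms", "ddim", "k_euler", "k_euler_a"]


# Literal accepts a tuple subscript at runtime, so the allowed values
# can be built dynamically from the list above.
SAMPLER_NAME_VALUES = Literal[tuple(schedulers())]  # type: ignore


class SchedulerArgs(BaseModel):
    scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use")


print(SchedulerArgs(scheduler="ddim").scheduler)  # ddim
# SchedulerArgs(scheduler="nonexistent")  # would raise a ValidationError
```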
--- invokeai/app/invocations/generate.py | 2 +- invokeai/app/invocations/latent.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 153d11189e..70c695fd2e 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -35,7 +35,7 @@ class TextToImageInvocation(BaseInvocation): width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", ) height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", ) cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) - sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" ) + scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" ) seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) model: str = Field(default="", description="The model to use (currently ignored)") progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 49c3c4f11e..ca3c7246c7 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -136,7 +136,7 @@ class TextToLatentsInvocation(BaseInvocation): width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", ) height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", ) cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) - sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" ) + scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" ) seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") model: str = Field(default="", description="The model to use (currently ignored)") @@ -175,7 +175,7 @@ class TextToLatentsInvocation(BaseInvocation): model: StableDiffusionGeneratorPipeline = model_info['model'] model.scheduler = get_scheduler( model=model, - scheduler_name=self.sampler_name + scheduler_name=self.scheduler ) if isinstance(model, DiffusionPipeline): From 427db7c7e2dd4b54cf10b4f4328a16c99da9151e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 9 Apr 2023 22:33:16 +1000 Subject: [PATCH 4/9] feat(nodes): fix typo in PasteImageInvocation --- invokeai/app/invocations/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 491a4895a6..9f072036a4 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -139,7 +139,7 @@ class PasteImageInvocation(BaseInvocation): None if self.mask is None else ImageOps.invert( - services.images.get(self.mask.image_type, self.mask.image_name) + context.services.images.get(self.mask.image_type, self.mask.image_name) ) ) # TODO: probably shouldn't invert mask here... 
should user be required to do it? From 07e3a0ec1545ee64ecf76c1d118a8e995b5a4a2c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 10 Apr 2023 19:07:48 +1000 Subject: [PATCH 5/9] feat(nodes): add invocation schema customisation, add model selection - add invocation schema customisation done via fastapi's `Config` class and `schema_extra`. when using `Config`, inherit from `InvocationConfig` to get type hints. where it makes sense - like for all math invocations - define a `MathInvocationConfig` class and have all invocations inherit from it. this customisation can provide any arbitrary additional data to the UI. currently it provides tags and field type hints. this is necessary for `model` type fields, which are actually string fields. without something like this, we can't reliably differentiate `model` fields from normal `string` fields. can also be used for future field types. all invocations now have tags, and all `model` fields have ui type hints. - fix model handling for invocations added a helper to fall back to the default model if an invalid model name is chosen. model names in graphs now work. - fix latents progress callback noticed this wasn't correct while working on everything else. --- invokeai/app/invocations/cv.py | 17 ++++- invokeai/app/invocations/generate.py | 49 +++++++----- invokeai/app/invocations/image.py | 25 ++++-- invokeai/app/invocations/latent.py | 88 +++++++++++++++------- invokeai/app/invocations/math.py | 29 ++++--- invokeai/app/invocations/models/config.py | 54 +++++++++++++ invokeai/app/invocations/reconstruct.py | 9 +++ invokeai/app/invocations/upscale.py | 10 +++ invokeai/app/invocations/util/get_model.py | 11 +++ 9 files changed, 228 insertions(+), 64 deletions(-) create mode 100644 invokeai/app/invocations/models/config.py create mode 100644 invokeai/app/invocations/util/get_model.py diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index ce784313cf..adcb2405c4 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,14 +5,27 @@ from typing import Literal import cv2 as cv import numpy from PIL import Image, ImageOps -from pydantic import Field +from pydantic import BaseModel, Field +from invokeai.app.invocations.models.config import InvocationConfig from invokeai.app.models.image import ImageField, ImageType from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageOutput -class CvInpaintInvocation(BaseInvocation): +class CVInvocation(BaseModel): + """Helper class to provide all OpenCV invocations with additional config""" + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["cv", "image"], + }, + } + + +class CvInpaintInvocation(BaseInvocation, CVInvocation): """Simple inpaint using opencv.""" #fmt: off type: Literal["cv_inpaint"] = "cv_inpaint" diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 70c695fd2e..a07b1ac379 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -6,9 +6,13 @@ from typing import Literal, Optional, Union import numpy as np from torch import Tensor -from pydantic import Field +from pydantic import BaseModel, Field +from invokeai.app.invocations.models.config import ( + InvocationConfig, +) from invokeai.app.models.image import ImageField, ImageType +from invokeai.app.invocations.util.get_model import choose_model from .baseinvocation import BaseInvocation, 
InvocationContext from .image import ImageOutput from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator @@ -16,12 +20,26 @@ from ...backend.stable_diffusion import PipelineIntermediateState from ..models.exceptions import CanceledException from ..util.step_callback import diffusers_step_callback_adapter -SAMPLER_NAME_VALUES = Literal[ - tuple(InvokeAIGenerator.schedulers()) -] +SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] + + +class SDImageInvocation(BaseModel): + """Helper class to provide all Stable Diffusion raster image invocations with additional config""" + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["stable-diffusion", "image"], + "type_hints": { + "model": "model", + }, + }, + } + # Text to image -class TextToImageInvocation(BaseInvocation): +class TextToImageInvocation(BaseInvocation, SDImageInvocation): """Generates an image using text2img.""" type: Literal["txt2img"] = "txt2img" @@ -59,16 +77,9 @@ class TextToImageInvocation(BaseInvocation): diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context) def invoke(self, context: InvocationContext) -> ImageOutput: - # def step_callback(state: PipelineIntermediateState): - # if (context.services.queue.is_canceled(context.graph_execution_state_id)): - # raise CanceledException - # self.dispatch_progress(context, state.latents, state.step) - # Handle invalid model parameter - # TODO: figure out if this can be done via a validator that uses the model_cache - # TODO: How to get the default model name now? - # (right now uses whatever current model is set in model manager) - model= context.services.model_manager.get_model() + model = choose_model(context.services.model_manager, self.model) + outputs = Txt2Img(model).generate( prompt=self.prompt, step_callback=partial(self.dispatch_progress, context), @@ -135,9 +146,8 @@ class ImageToImageInvocation(TextToImageInvocation): mask = None # Handle invalid model parameter - # TODO: figure out if this can be done via a validator that uses the model_cache - # TODO: How to get the default model name now? - model = context.services.model_manager.get_model() + model = choose_model(context.services.model_manager, self.model) + outputs = Img2Img(model).generate( prompt=self.prompt, init_image=image, @@ -211,9 +221,8 @@ class InpaintInvocation(ImageToImageInvocation): ) # Handle invalid model parameter - # TODO: figure out if this can be done via a validator that uses the model_cache - # TODO: How to get the default model name now? 
- model = context.services.model_manager.get_model() + model = choose_model(context.services.model_manager, self.model) + outputs = Inpaint(model).generate( prompt=self.prompt, init_img=image, diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 9f072036a4..0f783b2541 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,10 +7,23 @@ import numpy from PIL import Image, ImageFilter, ImageOps from pydantic import BaseModel, Field +from invokeai.app.invocations.models.config import InvocationConfig + from ..models.image import ImageField, ImageType from ..services.invocation_services import InvocationServices from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext + +class PILInvocationConfig(BaseModel): + """Helper class to provide all PIL invocations with additional config""" + + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["PIL", "image"], + }, + } + class ImageOutput(BaseInvocationOutput): """Base class for invocations that output an image""" #fmt: off @@ -82,7 +95,7 @@ class ShowImageInvocation(BaseInvocation): ) -class CropImageInvocation(BaseInvocation): +class CropImageInvocation(BaseInvocation, PILInvocationConfig): """Crops an image to a specified box. The box can be outside of the image.""" #fmt: off type: Literal["crop"] = "crop" @@ -115,7 +128,7 @@ class CropImageInvocation(BaseInvocation): ) -class PasteImageInvocation(BaseInvocation): +class PasteImageInvocation(BaseInvocation, PILInvocationConfig): """Pastes an image into another image.""" #fmt: off type: Literal["paste"] = "paste" @@ -165,7 +178,7 @@ class PasteImageInvocation(BaseInvocation): ) -class MaskFromAlphaInvocation(BaseInvocation): +class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): """Extracts the alpha channel of an image as a mask.""" #fmt: off type: Literal["tomask"] = "tomask" @@ -192,7 +205,7 @@ class MaskFromAlphaInvocation(BaseInvocation): return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name)) -class BlurInvocation(BaseInvocation): +class BlurInvocation(BaseInvocation, PILInvocationConfig): """Blurs an image""" #fmt: off @@ -226,7 +239,7 @@ class BlurInvocation(BaseInvocation): ) -class LerpInvocation(BaseInvocation): +class LerpInvocation(BaseInvocation, PILInvocationConfig): """Linear interpolation of all pixels of an image""" #fmt: off type: Literal["lerp"] = "lerp" @@ -257,7 +270,7 @@ class LerpInvocation(BaseInvocation): ) -class InverseLerpInvocation(BaseInvocation): +class InverseLerpInvocation(BaseInvocation, PILInvocationConfig): """Inverse linear interpolation of all pixels of an image""" #fmt: off type: Literal["ilerp"] = "ilerp" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ca3c7246c7..d59d681f40 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -2,9 +2,13 @@ from typing import Literal, Optional from pydantic import BaseModel, Field -from torch import Tensor import torch +from invokeai.app.invocations.models.config import InvocationConfig +from invokeai.app.models.exceptions import CanceledException +from invokeai.app.invocations.util.get_model import choose_model +from invokeai.app.util.step_callback import diffusers_step_callback_adapter + from ...backend.model_management.model_manager import ModelManager from ...backend.util.devices import choose_torch_device, torch_dtype from 
...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings @@ -13,13 +17,10 @@ from ...backend.prompting.conditioning import get_uc_and_c_and_ec from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext import numpy as np -from accelerate.utils import set_seed from ..services.image_storage import ImageType from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput -from ...backend.generator import Generator from ...backend.stable_diffusion import PipelineIntermediateState -from ...backend.util.util import image_to_dataURL from diffusers.schedulers import SchedulerMixin as Scheduler import diffusers from diffusers import DiffusionPipeline @@ -109,6 +110,15 @@ class NoiseInvocation(BaseInvocation): width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", ) height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", ) + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents", "noise"], + }, + } + def invoke(self, context: InvocationContext) -> NoiseOutput: device = torch.device(choose_torch_device()) noise = get_noise(self.width, self.height, device, self.seed) @@ -143,33 +153,37 @@ class TextToLatentsInvocation(BaseInvocation): progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) # fmt: on + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents", "image"], + "type_hints": { + "model": "model" + } + }, + } + # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, sample: Tensor, step: int + self, context: InvocationContext, intermediate_state: PipelineIntermediateState ) -> None: - # TODO: only output a preview image when requested - image = Generator.sample_to_lowres_estimated_image(sample) + if (context.services.queue.is_canceled(context.graph_execution_state_id)): + raise CanceledException - (width, height) = image.size - width *= 8 - height *= 8 + step = intermediate_state.step + if intermediate_state.predicted_original is not None: + # Some schedulers report not only the noisy latents at the current timestep, + # but also their estimate so far of what the de-noised latents will be. 
+ sample = intermediate_state.predicted_original + else: + sample = intermediate_state.latents - dataURL = image_to_dataURL(image, image_format="JPEG") + diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context) - context.services.events.emit_generator_progress( - context.graph_execution_state_id, - self.id, - { - "width": width, - "height": height, - "dataURL": dataURL - }, - step, - self.steps, - ) def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline: - model_info = model_manager.get_model(self.model) + model_info = choose_model(model_manager, self.model) model_name = model_info['model_name'] model_hash = model_info['hash'] model: StableDiffusionGeneratorPipeline = model_info['model'] @@ -214,7 +228,7 @@ class TextToLatentsInvocation(BaseInvocation): noise = context.services.latents.get(self.noise.latents_name) def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, state.latents, state.step) + self.dispatch_progress(context, state) model = self.get_model(context.services.model_manager) conditioning_data = self.get_conditioning_data(model) @@ -244,6 +258,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): type: Literal["l2l"] = "l2l" + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents"], + "type_hints": { + "model": "model" + } + }, + } + # Inputs latents: Optional[LatentsField] = Field(description="The latents to use as a base image") strength: float = Field(default=0.5, description="The strength of the latents to use") @@ -253,7 +278,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): latent = context.services.latents.get(self.latents.latents_name) def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, state.latents, state.step) + self.dispatch_progress(context, state) model = self.get_model(context.services.model_manager) conditioning_data = self.get_conditioning_data(model) @@ -299,12 +324,23 @@ class LatentsToImageInvocation(BaseInvocation): latents: Optional[LatentsField] = Field(description="The latents to generate an image from") model: str = Field(default="", description="The model to use") + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents", "image"], + "type_hints": { + "model": "model" + } + }, + } + @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) # TODO: this only really needs the vae - model_info = context.services.model_manager.get_model(self.model) + model_info = choose_model(context.services.model_manager, self.model) model: StableDiffusionGeneratorPipeline = model_info['model'] with torch.inference_mode(): diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index ecdcc834c7..e9d90b1d61 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -1,17 +1,26 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from datetime import datetime, timezone -from typing import Literal, Optional +from typing import Literal -import numpy -from PIL import Image, ImageFilter, ImageOps from pydantic import BaseModel, Field -from ..services.image_storage import ImageType -from ..services.invocation_services import InvocationServices +from invokeai.app.invocations.models.config import InvocationConfig + from .baseinvocation import BaseInvocation, 
BaseInvocationOutput, InvocationContext +class MathInvocationConfig(BaseModel): + """Helper class to provide all math invocations with additional config""" + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["math"], + } + } + + class IntOutput(BaseInvocationOutput): """An integer output""" #fmt: off @@ -20,7 +29,7 @@ class IntOutput(BaseInvocationOutput): #fmt: on -class AddInvocation(BaseInvocation): +class AddInvocation(BaseInvocation, MathInvocationConfig): """Adds two numbers""" #fmt: off type: Literal["add"] = "add" @@ -32,7 +41,7 @@ class AddInvocation(BaseInvocation): return IntOutput(a=self.a + self.b) -class SubtractInvocation(BaseInvocation): +class SubtractInvocation(BaseInvocation, MathInvocationConfig): """Subtracts two numbers""" #fmt: off type: Literal["sub"] = "sub" @@ -44,7 +53,7 @@ class SubtractInvocation(BaseInvocation): return IntOutput(a=self.a - self.b) -class MultiplyInvocation(BaseInvocation): +class MultiplyInvocation(BaseInvocation, MathInvocationConfig): """Multiplies two numbers""" #fmt: off type: Literal["mul"] = "mul" @@ -56,7 +65,7 @@ class MultiplyInvocation(BaseInvocation): return IntOutput(a=self.a * self.b) -class DivideInvocation(BaseInvocation): +class DivideInvocation(BaseInvocation, MathInvocationConfig): """Divides two numbers""" #fmt: off type: Literal["div"] = "div" diff --git a/invokeai/app/invocations/models/config.py b/invokeai/app/invocations/models/config.py new file mode 100644 index 0000000000..f53bbdda00 --- /dev/null +++ b/invokeai/app/invocations/models/config.py @@ -0,0 +1,54 @@ +from typing import Dict, List, Literal, TypedDict +from pydantic import BaseModel + + +# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False` +class UIConfig(TypedDict, total=False): + type_hints: Dict[ + str, + Literal[ + "integer", + "float", + "boolean", + "string", + "enum", + "image", + "latents", + "model", + ], + ] + tags: List[str] + + +class CustomisedSchemaExtra(TypedDict): + ui: UIConfig + + +class InvocationConfig(BaseModel.Config): + """Customizes pydantic's BaseModel.Config class for use by Invocations. + + Provide `schema_extra` a `ui` dict to add hints for generated UIs. + + `tags` + - A list of strings, used to categorise invocations. + + `type_hints` + - A dict of field types which override the types in the invocation definition. + - Each key should be the name of one of the invocation's fields. 
+ - Each value should be one of the valid types: + - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model` + + ```python + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["stable-diffusion", "image"], + "type_hints": { + "initial_image": "image", + }, + }, + } + ``` + """ + + schema_extra: CustomisedSchemaExtra diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index 68449729d6..db0a28c075 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone from typing import Literal, Union from pydantic import Field +from invokeai.app.invocations.models.config import InvocationConfig from invokeai.app.models.image import ImageField, ImageType from ..services.invocation_services import InvocationServices @@ -18,6 +19,14 @@ class RestoreFaceInvocation(BaseInvocation): strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" ) #fmt: on + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["restoration", "image"], + }, + } + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get( self.image.image_type, self.image.image_name diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index ea3221572e..c9aa86cc6d 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -4,6 +4,7 @@ from datetime import datetime, timezone from typing import Literal, Union from pydantic import Field +from invokeai.app.invocations.models.config import InvocationConfig from invokeai.app.models.image import ImageField, ImageType from ..services.invocation_services import InvocationServices @@ -22,6 +23,15 @@ class UpscaleInvocation(BaseInvocation): level: Literal[2, 4] = Field(default=2, description="The upscale level") #fmt: on + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["upscaling", "image"], + }, + } + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get( self.image.image_type, self.image.image_name diff --git a/invokeai/app/invocations/util/get_model.py b/invokeai/app/invocations/util/get_model.py new file mode 100644 index 0000000000..d3484a0b9d --- /dev/null +++ b/invokeai/app/invocations/util/get_model.py @@ -0,0 +1,11 @@ +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.backend.model_management.model_manager import ModelManager + + +def choose_model(model_manager: ModelManager, model_name: str): + """Returns the default model if the `model_name` not a valid model, else returns the selected model.""" + if model_manager.valid_model(model_name): + return model_manager.get_model(model_name) + else: + print(f"* Warning: '{model_name}' is not a valid model name. 
Using default model instead.") + return model_manager.get_model() \ No newline at end of file From 1f2c1e14dbf45220bf46dfd5d05312ebab408a88 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 11 Apr 2023 08:24:48 +1000 Subject: [PATCH 6/9] fix(nodes): move InvocationConfig to baseinvocation.py --- invokeai/app/invocations/baseinvocation.py | 55 +++++++++++++++++++++- invokeai/app/invocations/cv.py | 3 +- invokeai/app/invocations/generate.py | 5 +- invokeai/app/invocations/image.py | 4 +- invokeai/app/invocations/latent.py | 3 +- invokeai/app/invocations/math.py | 4 +- invokeai/app/invocations/models/config.py | 54 --------------------- invokeai/app/invocations/reconstruct.py | 3 +- invokeai/app/invocations/upscale.py | 3 +- 9 files changed, 61 insertions(+), 73 deletions(-) delete mode 100644 invokeai/app/invocations/models/config.py diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 72fe39ed0b..3590129b96 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from inspect import signature -from typing import get_args, get_type_hints +from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict from pydantic import BaseModel, Field @@ -76,3 +76,56 @@ class BaseInvocation(ABC, BaseModel): #fmt: off id: str = Field(description="The id of this node. Must be unique among all nodes.") #fmt: on + + +# TODO: figure out a better way to provide these hints +# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False` +class UIConfig(TypedDict, total=False): + type_hints: Dict[ + str, + Literal[ + "integer", + "float", + "boolean", + "string", + "enum", + "image", + "latents", + "model", + ], + ] + tags: List[str] + + +class CustomisedSchemaExtra(TypedDict): + ui: UIConfig + + +class InvocationConfig(BaseModel.Config): + """Customizes pydantic's BaseModel.Config class for use by Invocations. + + Provide `schema_extra` a `ui` dict to add hints for generated UIs. + + `tags` + - A list of strings, used to categorise invocations. + + `type_hints` + - A dict of field types which override the types in the invocation definition. + - Each key should be the name of one of the invocation's fields. 
+ - Each value should be one of the valid types: + - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model` + + ```python + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["stable-diffusion", "image"], + "type_hints": { + "initial_image": "image", + }, + }, + } + ``` + """ + + schema_extra: CustomisedSchemaExtra diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index adcb2405c4..9afbbbbcc9 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -6,10 +6,9 @@ import cv2 as cv import numpy from PIL import Image, ImageOps from pydantic import BaseModel, Field -from invokeai.app.invocations.models.config import InvocationConfig from invokeai.app.models.image import ImageField, ImageType -from .baseinvocation import BaseInvocation, InvocationContext +from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index a07b1ac379..d0eeeae698 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -7,13 +7,10 @@ import numpy as np from torch import Tensor from pydantic import BaseModel, Field -from invokeai.app.invocations.models.config import ( - InvocationConfig, -) from invokeai.app.models.image import ImageField, ImageType from invokeai.app.invocations.util.get_model import choose_model -from .baseinvocation import BaseInvocation, InvocationContext +from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator from ...backend.stable_diffusion import PipelineIntermediateState diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 0f783b2541..cc5f6b53c7 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,11 +7,9 @@ import numpy from PIL import Image, ImageFilter, ImageOps from pydantic import BaseModel, Field -from invokeai.app.invocations.models.config import InvocationConfig - from ..models.image import ImageField, ImageType from ..services.invocation_services import InvocationServices -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig class PILInvocationConfig(BaseModel): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index d59d681f40..2da6e451a9 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -4,7 +4,6 @@ from typing import Literal, Optional from pydantic import BaseModel, Field import torch -from invokeai.app.invocations.models.config import InvocationConfig from invokeai.app.models.exceptions import CanceledException from invokeai.app.invocations.util.get_model import choose_model from invokeai.app.util.step_callback import diffusers_step_callback_adapter @@ -15,7 +14,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post from ...backend.image_util.seamless import configure_model_padding from ...backend.prompting.conditioning import get_uc_and_c_and_ec from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from .baseinvocation import 
BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig import numpy as np from ..services.image_storage import ImageType from .baseinvocation import BaseInvocation, InvocationContext diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index e9d90b1d61..afb0e75377 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -4,9 +4,7 @@ from typing import Literal from pydantic import BaseModel, Field -from invokeai.app.invocations.models.config import InvocationConfig - -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig class MathInvocationConfig(BaseModel): diff --git a/invokeai/app/invocations/models/config.py b/invokeai/app/invocations/models/config.py deleted file mode 100644 index f53bbdda00..0000000000 --- a/invokeai/app/invocations/models/config.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Dict, List, Literal, TypedDict -from pydantic import BaseModel - - -# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False` -class UIConfig(TypedDict, total=False): - type_hints: Dict[ - str, - Literal[ - "integer", - "float", - "boolean", - "string", - "enum", - "image", - "latents", - "model", - ], - ] - tags: List[str] - - -class CustomisedSchemaExtra(TypedDict): - ui: UIConfig - - -class InvocationConfig(BaseModel.Config): - """Customizes pydantic's BaseModel.Config class for use by Invocations. - - Provide `schema_extra` a `ui` dict to add hints for generated UIs. - - `tags` - - A list of strings, used to categorise invocations. - - `type_hints` - - A dict of field types which override the types in the invocation definition. - - Each key should be the name of one of the invocation's fields. 
- - Each value should be one of the valid types: - - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model` - - ```python - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["stable-diffusion", "image"], - "type_hints": { - "initial_image": "image", - }, - }, - } - ``` - """ - - schema_extra: CustomisedSchemaExtra diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index db0a28c075..f6df5a2254 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -2,11 +2,10 @@ from datetime import datetime, timezone from typing import Literal, Union from pydantic import Field -from invokeai.app.invocations.models.config import InvocationConfig from invokeai.app.models.image import ImageField, ImageType from ..services.invocation_services import InvocationServices -from .baseinvocation import BaseInvocation, InvocationContext +from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput class RestoreFaceInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index c9aa86cc6d..021f3569e8 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -4,11 +4,10 @@ from datetime import datetime, timezone from typing import Literal, Union from pydantic import Field -from invokeai.app.invocations.models.config import InvocationConfig from invokeai.app.models.image import ImageField, ImageType from ..services.invocation_services import InvocationServices -from .baseinvocation import BaseInvocation, InvocationContext +from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput From d923d1d66bc4ac0a0d7cf931442ed85b7c1f00ab Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:50:28 +1000 Subject: [PATCH 7/9] fix(nodes): fix naming of CvInvocationConfig --- invokeai/app/invocations/cv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 9afbbbbcc9..52e59b16ac 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -12,7 +12,7 @@ from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput -class CVInvocation(BaseModel): +class CvInvocationConfig(BaseModel): """Helper class to provide all OpenCV invocations with additional config""" # Schema customisation @@ -24,7 +24,7 @@ class CVInvocation(BaseModel): } -class CvInpaintInvocation(BaseInvocation, CVInvocation): +class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): """Simple inpaint using opencv.""" #fmt: off type: Literal["cv_inpaint"] = "cv_inpaint" From c44c19e911741b1386b8fa7a300ec7171eb95918 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20K=C3=B6rfer?= <86005583+nicholaskoerfer@users.noreply.github.com> Date: Thu, 13 Apr 2023 17:42:34 +0200 Subject: [PATCH 8/9] Fixed a Typo. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1b02ced2c9..56946df433 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ installing lots of models. 6. Wait while the installer does its thing. After installing the software, the installer will launch a script that lets you configure InvokeAI and -select a set of starting image generaiton models. 
+select a set of starting image generation models. 7. Find the folder that InvokeAI was installed into (it is not the same as the unpacked zip file directory!) The default location of this From 23d65e71621aefd98931394c9aece5c712fbb320 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Thu, 13 Apr 2023 23:41:06 -0700 Subject: [PATCH 9/9] [nodes] Add subgraph library, subgraph usage in CLI, and fix subgraph execution (#3180) * Add latent to latent (img2img equivalent) Fix a CLI bug with multiple links per node * Using "latents" instead of "latent" * [nodes] In-progress implementation of graph library * Add linking to CLI for graph nodes (still broken) * Fix subgraph execution, fix subgraph linking in CLI * Fix LatentsToLatents --- invokeai/app/api/dependencies.py | 9 +- invokeai/app/cli/commands.py | 96 +++++++++---- invokeai/app/cli_app.py | 139 +++++++++++++++---- invokeai/app/invocations/latent.py | 57 +++++++- invokeai/app/invocations/params.py | 18 +++ invokeai/app/services/default_graphs.py | 56 ++++++++ invokeai/app/services/graph.py | 87 +++++++++--- invokeai/app/services/invocation_services.py | 3 + invokeai/app/services/sqlite.py | 31 +---- pyproject.toml | 2 +- tests/nodes/test_graph_execution_state.py | 5 +- tests/nodes/test_invoker.py | 5 +- tests/nodes/test_node_graph.py | 66 ++++++++- 13 files changed, 471 insertions(+), 103 deletions(-) create mode 100644 invokeai/app/invocations/params.py create mode 100644 invokeai/app/services/default_graphs.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 5698d25758..cd5d8a61b2 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -3,12 +3,14 @@ import os from argparse import Namespace +from ..services.default_graphs import create_system_graphs + from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from ...backend import Globals from ..services.model_manager_initializer import get_model_manager from ..services.restoration_services import RestorationServices -from ..services.graph import GraphExecutionState +from ..services.graph import GraphExecutionState, LibraryGraph from ..services.image_storage import DiskImageStorage from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_services import InvocationServices @@ -69,6 +71,9 @@ class ApiDependencies: latents=latents, images=images, queue=MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=db_location, table_name="graphs" + ), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" ), @@ -76,6 +81,8 @@ class ApiDependencies: restoration=RestorationServices(config), ) + create_system_graphs(services.graph_library) + ApiDependencies.invoker = Invoker(services) @staticmethod diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py index 4e9c9aa581..5ad4827eb0 100644 --- a/invokeai/app/cli/commands.py +++ b/invokeai/app/cli/commands.py @@ -7,11 +7,40 @@ from pydantic import BaseModel, Field import networkx as nx import matplotlib.pyplot as plt -from ..models.image import ImageField -from ..services.graph import GraphExecutionState +from ..invocations.baseinvocation import BaseInvocation +from ..invocations.image import ImageField +from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge from ..services.invoker import Invoker +def add_field_argument(command_parser, name: str, field, default_override = None): + default 
= default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() + if get_origin(field.type_) == Literal: + allowed_values = get_args(field.type_) + allowed_types = set() + for val in allowed_values: + allowed_types.add(type(val)) + allowed_types_list = list(allowed_types) + field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore + + command_parser.add_argument( + f"--{name}", + dest=name, + type=field_type, + default=default, + choices=allowed_values, + help=field.field_info.description, + ) + else: + command_parser.add_argument( + f"--{name}", + dest=name, + type=field.type_, + default=default, + help=field.field_info.description, + ) + + def add_parsers( subparsers, commands: list[type], @@ -36,30 +65,26 @@ def add_parsers( if name in exclude_fields: continue - if get_origin(field.type_) == Literal: - allowed_values = get_args(field.type_) - allowed_types = set() - for val in allowed_values: - allowed_types.add(type(val)) - allowed_types_list = list(allowed_types) - field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore + add_field_argument(command_parser, name, field) - command_parser.add_argument( - f"--{name}", - dest=name, - type=field_type, - default=field.default if field.default_factory is None else field.default_factory(), - choices=allowed_values, - help=field.field_info.description, - ) - else: - command_parser.add_argument( - f"--{name}", - dest=name, - type=field.type_, - default=field.default if field.default_factory is None else field.default_factory(), - help=field.field_info.description, - ) + +def add_graph_parsers( + subparsers, + graphs: list[LibraryGraph], + add_arguments: Callable[[argparse.ArgumentParser], None]|None = None +): + for graph in graphs: + command_parser = subparsers.add_parser(graph.name, help=graph.description) + + if add_arguments is not None: + add_arguments(command_parser) + + # Add arguments for inputs + for exposed_input in graph.exposed_inputs: + node = graph.graph.get_node(exposed_input.node_path) + field = node.__fields__[exposed_input.field] + default_override = getattr(node, exposed_input.field) + add_field_argument(command_parser, exposed_input.alias, field, default_override) class CliContext: @@ -67,17 +92,38 @@ class CliContext: session: GraphExecutionState parser: argparse.ArgumentParser defaults: dict[str, Any] + graph_nodes: dict[str, str] + nodes_added: list[str] def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser): self.invoker = invoker self.session = session self.parser = parser self.defaults = dict() + self.graph_nodes = dict() + self.nodes_added = list() def get_session(self): self.session = self.invoker.services.graph_execution_manager.get(self.session.id) return self.session + def reset(self): + self.session = self.invoker.create_execution_state() + self.graph_nodes = dict() + self.nodes_added = list() + # Leave defaults unchanged + + def add_node(self, node: BaseInvocation): + self.get_session() + self.session.graph.add_node(node) + self.nodes_added.append(node.id) + self.invoker.services.graph_execution_manager.set(self.session) + + def add_edge(self, edge: Edge): + self.get_session() + self.session.add_edge(edge) + self.invoker.services.graph_execution_manager.set(self.session) + class ExitCli(Exception): """Exception to exit the CLI""" diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 
a257825dcc..86fd18ca60 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -13,17 +13,20 @@ from typing import ( from pydantic import BaseModel from pydantic.fields import Field +from .services.default_graphs import create_system_graphs + from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from ..backend import Args -from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history +from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history from .cli.completer import set_autocompleter from .invocations import * from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase from .services.model_manager_initializer import get_model_manager from .services.restoration_services import RestorationServices -from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible +from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible +from .services.default_graphs import default_text_to_image_graph_id from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_services import InvocationServices @@ -58,7 +61,7 @@ def add_invocation_args(command_parser): ) -def get_command_parser() -> argparse.ArgumentParser: +def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser: # Create invocation parser parser = argparse.ArgumentParser() @@ -76,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser: commands = BaseCommand.get_all_subclasses() add_parsers(subparsers, commands, exclude_fields=["type"]) + # Create subparsers for exposed CLI graphs + # TODO: add a way to identify these graphs + text_to_image = services.graph_library.get(default_text_to_image_graph_id) + add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args) + return parser +class NodeField(): + alias: str + node_path: str + field: str + field_type: type + + def __init__(self, alias: str, node_path: str, field: str, field_type: type): + self.alias = alias + self.node_path = node_path + self.field = field + self.field_type = field_type + + +def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]: + return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()} + + +def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField: + """Gets the node field for the specified field alias""" + exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias) + node_type = type(graph.graph.get_node(exposed_input.node_path)) + return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field]) + + +def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField: + """Gets the node field for the specified field alias""" + exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias) + node_type = type(graph.graph.get_node(exposed_output.node_path)) + node_output_type = node_type.get_output_type() + return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, 
field_type=get_type_hints(node_output_type)[exposed_output.field]) + + +def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]: + """Gets the inputs for the specified invocation from the context""" + node_type = type(invocation) + if node_type is not GraphInvocation: + return fields_from_type_hints(get_type_hints(node_type), invocation.id) + else: + graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id]) + return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs} + + +def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]: + """Gets the outputs for the specified invocation from the context""" + node_type = type(invocation) + if node_type is not GraphInvocation: + return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id) + else: + graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id]) + return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs} + + def generate_matching_edges( - a: BaseInvocation, b: BaseInvocation + a: BaseInvocation, b: BaseInvocation, context: CliContext ) -> list[Edge]: """Generates all possible edges between two invocations""" - atype = type(a) - btype = type(b) - - aoutputtype = atype.get_output_type() - - afields = get_type_hints(aoutputtype) - bfields = get_type_hints(btype) + afields = get_node_outputs(a, context) + bfields = get_node_inputs(b, context) matching_fields = set(afields.keys()).intersection(bfields.keys()) @@ -98,14 +153,14 @@ def generate_matching_edges( matching_fields = matching_fields.difference(invalid_fields) # Validate types - matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])] + matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)] edges = [ Edge( - source=EdgeConnection(node_id=a.id, field=field), - destination=EdgeConnection(node_id=b.id, field=field) + source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field), + destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field) ) - for field in matching_fields + for alias in matching_fields ] return edges @@ -158,6 +213,9 @@ def invoke_cli(): latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), images=DiskImageStorage(f'{output_folder}/images'), queue=MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=db_location, table_name="graphs" + ), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" ), @@ -165,9 +223,12 @@ def invoke_cli(): restoration=RestorationServices(config), ) + system_graphs = create_system_graphs(services.graph_library) + system_graph_names = set([g.name for g in system_graphs]) + invoker = Invoker(services) session: GraphExecutionState = invoker.create_execution_state() - parser = get_command_parser() + parser = get_command_parser(services) re_negid = re.compile('^-[0-9]+$') @@ -185,11 +246,12 @@ def invoke_cli(): try: # Refresh the state of the session - history = list(get_graph_execution_history(context.session)) + #history = list(get_graph_execution_history(context.session)) + history = list(reversed(context.nodes_added)) # Split the command for piping cmds = cmd_input.split("|") - start_id = 
len(history) + start_id = len(context.nodes_added) current_id = start_id new_invocations = list() for cmd in cmds: @@ -205,8 +267,24 @@ def invoke_cli(): args[field_name] = field_default # Parse invocation - args["id"] = current_id - command = CliCommand(command=args) + command: CliCommand = None # type:ignore + system_graph: LibraryGraph|None = None + if args['type'] in system_graph_names: + system_graph = next(filter(lambda g: g.name == args['type'], system_graphs)) + invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id)) + for exposed_input in system_graph.exposed_inputs: + if exposed_input.alias in args: + node = invocation.graph.get_node(exposed_input.node_path) + field = exposed_input.field + setattr(node, field, args[exposed_input.alias]) + command = CliCommand(command = invocation) + context.graph_nodes[invocation.id] = system_graph.id + else: + args["id"] = current_id + command = CliCommand(command=args) + + if command is None: + continue # Run any CLI commands immediately if isinstance(command.command, BaseCommand): @@ -217,6 +295,7 @@ def invoke_cli(): command.command.run(context) continue + # TODO: handle linking with library graphs # Pipe previous command output (if there was a previous command) edges: list[Edge] = list() if len(history) > 0 or current_id != start_id: @@ -229,7 +308,7 @@ def invoke_cli(): else context.session.graph.get_node(from_id) ) matching_edges = generate_matching_edges( - from_node, command.command + from_node, command.command, context ) edges.extend(matching_edges) @@ -242,7 +321,7 @@ def invoke_cli(): link_node = context.session.graph.get_node(node_id) matching_edges = generate_matching_edges( - link_node, command.command + link_node, command.command, context ) matching_destinations = [e.destination for e in matching_edges] edges = [e for e in edges if e.destination not in matching_destinations] @@ -256,12 +335,14 @@ def invoke_cli(): if re_negid.match(node_id): node_id = str(current_id + int(node_id)) + # TODO: handle missing input/output + node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]] + node_input = get_node_inputs(command.command, context)[link[2]] + edges.append( Edge( - source=EdgeConnection(node_id=node_id, field=link[1]), - destination=EdgeConnection( - node_id=command.command.id, field=link[2] - ) + source=EdgeConnection(node_id=node_output.node_path, field=node_output.field), + destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field) ) ) @@ -270,10 +351,10 @@ def invoke_cli(): current_id = current_id + 1 # Add the node to the session - context.session.add_node(command.command) + context.add_node(command.command) for edge in edges: print(edge) - context.session.add_edge(edge) + context.add_edge(edge) # Execute all remaining nodes invoke_all(context) @@ -285,7 +366,7 @@ def invoke_cli(): except SessionError: # Start a new session print("Session error: creating a new session") - context.session = context.invoker.create_execution_state() + context.reset() except ExitCli: break diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 2da6e451a9..ef17962f89 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +import random from typing import Literal, Optional from pydantic import BaseModel, Field import torch @@ -99,13 +100,17 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c 
return x +def random_seed(): + return random.randint(0, np.iinfo(np.uint32).max) + + class NoiseInvocation(BaseInvocation): """Generates latent noise.""" type: Literal["noise"] = "noise" # Inputs - seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", ) + seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed) width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", ) height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", ) @@ -313,6 +318,56 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ) +class LatentsToLatentsInvocation(TextToLatentsInvocation): + """Generates latents using latents as base image.""" + + type: Literal["l2l"] = "l2l" + + # Inputs + latents: Optional[LatentsField] = Field(description="The latents to use as a base image") + strength: float = Field(default=0.5, description="The strength of the latents to use") + + def invoke(self, context: InvocationContext) -> LatentsOutput: + noise = context.services.latents.get(self.noise.latents_name) + latent = context.services.latents.get(self.latents.latents_name) + + def step_callback(state: PipelineIntermediateState): + self.dispatch_progress(context, state) + + model = self.get_model(context.services.model_manager) + conditioning_data = self.get_conditioning_data(model) + + # TODO: Verify the noise is the right size + + initial_latents = latent if self.strength < 1.0 else torch.zeros_like( + latent, device=model.device, dtype=latent.dtype + ) + + timesteps, _ = model.get_img2img_timesteps( + self.steps, + self.strength, + device=model.device, + ) + + result_latents, result_attention_map_saver = model.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + callback=step_callback + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + torch.cuda.empty_cache() + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, result_latents) + return LatentsOutput( + latents=LatentsField(latents_name=name) + ) + + # Latent to image class LatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py new file mode 100644 index 0000000000..fcc7f1737a --- /dev/null +++ b/invokeai/app/invocations/params.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal +from pydantic import Field +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from .math import IntOutput + +# Pass-through parameter nodes - used by subgraphs + +class ParamIntInvocation(BaseInvocation): + """An integer parameter""" + #fmt: off + type: Literal["param_int"] = "param_int" + a: int = Field(default=0, description="The integer value") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a) diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py new file mode 100644 index 0000000000..637d906e75 --- /dev/null +++ b/invokeai/app/services/default_graphs.py @@ -0,0 +1,56 @@ +from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation +from ..invocations.params import ParamIntInvocation 
diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py
new file mode 100644
index 0000000000..637d906e75
--- /dev/null
+++ b/invokeai/app/services/default_graphs.py
@@ -0,0 +1,56 @@
+from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
+from ..invocations.params import ParamIntInvocation
+
+from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
+from .item_storage import ItemStorageABC
+
+
+default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
+
+
+def create_text_to_image() -> LibraryGraph:
+    return LibraryGraph(
+        id=default_text_to_image_graph_id,
+        name='t2i',
+        description='Converts text to an image',
+        graph=Graph(
+            nodes={
+                'width': ParamIntInvocation(id='width', a=512),
+                'height': ParamIntInvocation(id='height', a=512),
+                '3': NoiseInvocation(id='3'),
+                '4': TextToLatentsInvocation(id='4'),
+                '5': LatentsToImageInvocation(id='5')
+            },
+            edges=[
+                Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
+                Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
+                Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
+                Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
+                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
+                Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
+            ]
+        ),
+        exposed_inputs=[
+            ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
+            ExposedNodeInput(node_path='width', field='a', alias='width'),
+            ExposedNodeInput(node_path='height', field='a', alias='height')
+        ],
+        exposed_outputs=[
+            ExposedNodeOutput(node_path='5', field='image', alias='image')
+        ])
+
+
+def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
+    """Creates the default system graphs, or adds new versions if the old ones don't match"""
+
+    graphs: list[LibraryGraph] = list()
+
+    text_to_image = graph_library.get(default_text_to_image_graph_id)
+
+    # TODO: Check if the graph is the same as the default one, and if not, update it
+    #if text_to_image is None:
+    text_to_image = create_text_to_image()
+    graph_library.set(text_to_image)
+
+    graphs.append(text_to_image)
+
+    return graphs
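The library graph above exposes aliased inputs and outputs so callers can address the graph without knowing node ids. A small sketch of what the default graph exposes, using only names defined in the file above:

    from invokeai.app.services.default_graphs import create_text_to_image

    lg = create_text_to_image()
    assert [i.alias for i in lg.exposed_inputs] == ['prompt', 'width', 'height']
    assert [o.alias for o in lg.exposed_outputs] == ['image']
    assert lg.graph.has_node('width') and lg.graph.has_node('5')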
""" - if self._is_edge_valid(edge) and edge not in self.edges: + self._validate_edge(edge) + if edge not in self.edges: self.edges.append(edge) else: raise InvalidEdgeError() @@ -354,7 +355,7 @@ class Graph(BaseModel): return True - def _is_edge_valid(self, edge: Edge) -> bool: + def _validate_edge(self, edge: Edge): """Validates that a new edge doesn't create a cycle in the graph""" # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) @@ -362,54 +363,53 @@ class Graph(BaseModel): from_node = self.get_node(edge.source.node_id) to_node = self.get_node(edge.destination.node_id) except NodeNotFoundError: - return False + raise InvalidEdgeError("One or both nodes don't exist") # Validate that an edge to this node+field doesn't already exist input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): - return False + raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists') # Validate that no cycles would be created g = self.nx_graph_flat() g.add_edge(edge.source.node_id, edge.destination.node_id) if not nx.is_directed_acyclic_graph(g): - return False + raise InvalidEdgeError(f'Edge creates a cycle in the graph') # Validate that the field types are compatible if not are_connections_compatible( from_node, edge.source.field, to_node, edge.destination.field ): - return False + raise InvalidEdgeError(f'Fields are incompatible') # Validate if iterator output type matches iterator input type (if this edge results in both being set) if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if not self._is_iterator_connection_valid( edge.destination.node_id, new_input=edge.source ): - return False + raise InvalidEdgeError(f'Iterator input type does not match iterator output type') # Validate if iterator input type matches output type (if this edge results in both being set) if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if not self._is_iterator_connection_valid( edge.source.node_id, new_output=edge.destination ): - return False + raise InvalidEdgeError(f'Iterator output type does not match iterator input type') # Validate if collector input type matches output type (if this edge results in both being set) if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if not self._is_collector_connection_valid( edge.destination.node_id, new_input=edge.source ): - return False + raise InvalidEdgeError(f'Collector output type does not match collector input type') # Validate if collector output type matches input type (if this edge results in both being set) if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if not self._is_collector_connection_valid( edge.source.node_id, new_output=edge.destination ): - return False + raise InvalidEdgeError(f'Collector input type does not match collector output type') - return True def has_node(self, node_path: str) -> bool: """Determines whether or not a node exists in the graph.""" @@ -733,7 +733,7 @@ class Graph(BaseModel): for sgn in ( gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation) ): - sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) + g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) # TODO: figure out if iteration nodes need to be expanded @@ -858,7 +858,8 @@ class GraphExecutionState(BaseModel): def 
@@ -858,7 +858,8 @@ class GraphExecutionState(BaseModel):
 
     def is_complete(self) -> bool:
         """Returns true if the graph is complete"""
-        return self.has_error() or all((k in self.executed for k in self.graph.nodes))
+        node_ids = set(self.graph.nx_graph_flat().nodes)
+        return self.has_error() or all((k in self.executed for k in node_ids))
 
     def has_error(self) -> bool:
         """Returns true if the graph has any errors"""
@@ -946,11 +947,11 @@ class GraphExecutionState(BaseModel):
 
     def _iterator_graph(self) -> nx.DiGraph:
         """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
-        g = self.graph.nx_graph()
+        g = self.graph.nx_graph_flat()
         collectors = (
             n
             for n in self.graph.nodes
-            if isinstance(self.graph.nodes[n], CollectInvocation)
+            if isinstance(self.graph.get_node(n), CollectInvocation)
         )
         for c in collectors:
             g.remove_edges_from(list(g.in_edges(c)))
@@ -962,7 +963,7 @@ class GraphExecutionState(BaseModel):
         iterators = [
             n
             for n in nx.ancestors(g, node_id)
-            if isinstance(self.graph.nodes[n], IterateInvocation)
+            if isinstance(self.graph.get_node(n), IterateInvocation)
         ]
 
         return iterators
@@ -1098,7 +1099,9 @@ class GraphExecutionState(BaseModel):
 
     # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
     def _is_edge_valid(self, edge: Edge) -> bool:
-        if not self._is_edge_valid(edge):
+        try:
+            self.graph._validate_edge(edge)
+        except InvalidEdgeError:
             return False
 
         # Invalid if destination has already been prepared or executed
@@ -1144,4 +1147,52 @@ class GraphExecutionState(BaseModel):
         self.graph.delete_edge(edge)
 
 
+class ExposedNodeInput(BaseModel):
+    node_path: str = Field(description="The node path to the node with the input")
+    field: str = Field(description="The field name of the input")
+    alias: str = Field(description="The alias of the input")
+
+
+class ExposedNodeOutput(BaseModel):
+    node_path: str = Field(description="The node path to the node with the output")
+    field: str = Field(description="The field name of the output")
+    alias: str = Field(description="The alias of the output")
+
+class LibraryGraph(BaseModel):
+    id: str = Field(description="The unique identifier for this library graph", default_factory=lambda: str(uuid.uuid4()))
+    graph: Graph = Field(description="The graph")
+    name: str = Field(description="The name of the graph")
+    description: str = Field(description="The description of the graph")
+    exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
+    exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
+
+    @validator('exposed_inputs', 'exposed_outputs')
+    def validate_exposed_aliases(cls, v):
+        if len(v) != len(set(i.alias for i in v)):
+            raise ValueError("Duplicate exposed alias")
+        return v
+
+    @root_validator
+    def validate_exposed_nodes(cls, values):
+        graph = values['graph']
+
+        # Validate exposed inputs
+        for exposed_input in values['exposed_inputs']:
+            if not graph.has_node(exposed_input.node_path):
+                raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
+            node = graph.get_node(exposed_input.node_path)
+            if get_input_field(node, exposed_input.field) is None:
+                raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
+
+        # Validate exposed outputs
+        for exposed_output in values['exposed_outputs']:
+            if not graph.has_node(exposed_output.node_path):
+                raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
does not exist") + node = graph.get_node(exposed_output.node_path) + if get_output_field(node, exposed_output.field) is None: + raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}") + + return values + + GraphInvocation.update_forward_refs() diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 2cd0f55fd9..c3c6bbce7e 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -19,6 +19,7 @@ class InvocationServices: restoration: RestorationServices # NOTE: we must forward-declare any types that include invocations, since invocations can use services + graph_library: ItemStorageABC["LibraryGraph"] graph_execution_manager: ItemStorageABC["GraphExecutionState"] processor: "InvocationProcessorABC" @@ -29,6 +30,7 @@ class InvocationServices: latents: LatentsStorageBase, images: ImageStorageBase, queue: InvocationQueueABC, + graph_library: ItemStorageABC["LibraryGraph"], graph_execution_manager: ItemStorageABC["GraphExecutionState"], processor: "InvocationProcessorABC", restoration: RestorationServices, @@ -38,6 +40,7 @@ class InvocationServices: self.latents = latents self.images = images self.queue = queue + self.graph_library = graph_library self.graph_execution_manager = graph_execution_manager self.processor = processor self.restoration = restoration diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index fd089014bb..e06ca8c1ac 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -35,8 +35,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._create_table() def _create_table(self): - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""CREATE TABLE IF NOT EXISTS {self._table_name} ( item TEXT, @@ -45,34 +44,27 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._cursor.execute( f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);""" ) - finally: - self._lock.release() + self._conn.commit() def _parse_item(self, item: str) -> T: item_type = get_args(self.__orig_class__)[0] return parse_raw_as(item_type, item) def set(self, item: T): - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", (item.json(),), ) self._conn.commit() - finally: - self._lock.release() self._on_changed(item) def get(self, id: str) -> Union[T, None]: - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),) ) result = self._cursor.fetchone() - finally: - self._lock.release() if not result: return None @@ -80,19 +72,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): return self._parse_item(result[0]) def delete(self, id: str): - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) ) self._conn.commit() - finally: - self._lock.release() self._on_deleted(id) def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""SELECT item FROM {self._table_name} LIMIT ? 
OFFSET ?;""", (per_page, page * per_page), @@ -103,8 +91,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""") count = self._cursor.fetchone()[0] - finally: - self._lock.release() pageCount = int(count / per_page) + 1 @@ -115,8 +101,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): def search( self, query: str, page: int = 0, per_page: int = 10 ) -> PaginatedResults[T]: - try: - self._lock.acquire() + with self._lock: self._cursor.execute( f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""", (f"%{query}%", per_page, page * per_page), @@ -130,8 +115,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): (f"%{query}%",), ) count = self._cursor.fetchone()[0] - finally: - self._lock.release() pageCount = int(count / per_page) + 1 diff --git a/pyproject.toml b/pyproject.toml index 3d72483237..ec6aabfb8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==1.0.5", "datasets", - "diffusers[torch]~=0.14", + "diffusers[torch]==0.14", "dnspython==2.2.1", "einops", "eventlet", diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 506b8653f8..f65129797e 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -7,7 +7,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState import pytest @@ -28,6 +28,9 @@ def mock_services(): images = None, # type: ignore latents = None, # type: ignore queue = MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=sqlite_memory, table_name="graphs" + ), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), restoration = None, # type: ignore diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 68df708bdd..46d532b9f7 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invoker import Invoker from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState +from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, 
diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py
index 68df708bdd..46d532b9f7 100644
--- a/tests/nodes/test_invoker.py
+++ b/tests/nodes/test_invoker.py
@@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
+from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
 
 import pytest
 
@@ -26,6 +26,9 @@ def mock_services() -> InvocationServices:
         images = None, # type: ignore
         latents = None, # type: ignore
         queue = MemoryInvocationQueue(),
+        graph_library=SqliteItemStorage[LibraryGraph](
+            filename=sqlite_memory, table_name="graphs"
+        ),
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
         processor = DefaultInvocationProcessor(),
         restoration = None, # type: ignore
diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py
index b864e1e47a..c7693b59c9 100644
--- a/tests/nodes/test_node_graph.py
+++ b/tests/nodes/test_node_graph.py
@@ -1,9 +1,11 @@
-from invokeai.app.invocations.image import *
-
 from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
 from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
 from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
 from invokeai.app.invocations.upscale import UpscaleInvocation
+from invokeai.app.invocations.image import *
+from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
+from invokeai.app.invocations.params import ParamIntInvocation
+from invokeai.app.services.default_graphs import create_text_to_image
 
 import pytest
 
@@ -417,6 +419,64 @@ def test_graph_gets_subgraph_node():
     assert result.id == '1'
     assert result == n1_1
 
+
+def test_graph_expands_subgraph():
+    g = Graph()
+    n1 = GraphInvocation(id = "1")
+    n1.graph = Graph()
+
+    n1_1 = AddInvocation(id = "1", a = 1, b = 2)
+    n1_2 = SubtractInvocation(id = "2", b = 3)
+    n1.graph.add_node(n1_1)
+    n1.graph.add_node(n1_2)
+    n1.graph.add_edge(create_edge("1","a","2","a"))
+
+    g.add_node(n1)
+
+    n2 = AddInvocation(id = "2", b = 5)
+    g.add_node(n2)
+    g.add_edge(create_edge("1.2","a","2","a"))
+
+    dg = g.nx_graph_flat()
+    assert set(dg.nodes) == set(['1.1', '1.2', '2'])
+    assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')])
+
+
+def test_graph_subgraph_t2i():
+    g = Graph()
+    n1 = GraphInvocation(id = "1")
+
+    # Get text to image default graph
+    lg = create_text_to_image()
+    n1.graph = lg.graph
+
+    g.add_node(n1)
+
+    n2 = ParamIntInvocation(id = "2", a = 512)
+    n3 = ParamIntInvocation(id = "3", a = 256)
+
+    g.add_node(n2)
+    g.add_node(n3)
+
+    g.add_edge(create_edge("2","a","1.width","a"))
+    g.add_edge(create_edge("3","a","1.height","a"))
+
+    n4 = ShowImageInvocation(id = "4")
+    g.add_node(n4)
+    g.add_edge(create_edge("1.5","image","4","image"))
+
+    # Validate
+    dg = g.nx_graph_flat()
+    assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
+    expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
+    expected_edges.extend([
+        ('2','1.width'),
+        ('3','1.height'),
+        ('1.5','4')
+    ])
+    assert set(dg.edges) == set(expected_edges)
+
+
 def test_graph_fails_to_get_missing_subgraph_node():
     g = Graph()
     n1 = GraphInvocation(id = "1")