diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 0c7e3069df..9af87e1ed4 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -2,8 +2,8 @@ from typing import Literal from pydantic.fields import Field -from .baseinvocation import BaseInvocationOutput - +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator class PromptOutput(BaseInvocationOutput): """Base class for invocations that output a prompt""" @@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput): 'prompt', ] } + + +class PromptCollectionOutput(BaseInvocationOutput): + """Base class for invocations that output a collection of prompts""" + + # fmt: off + type: Literal["prompt_collection_output"] = "prompt_collection_output" + + prompt_collection: list[str] = Field(description="The output prompt collection") + count: int = Field(description="The size of the prompt collection") + # fmt: on + + class Config: + schema_extra = {"required": ["type", "prompt_collection", "count"]} + + +class DynamicPromptInvocation(BaseInvocation): + """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator""" + + type: Literal["dynamic_prompt"] = "dynamic_prompt" + prompt: str = Field(description="The prompt to parse with dynamicprompts") + max_prompts: int = Field(default=1, description="The number of prompts to generate") + combinatorial: bool = Field( + default=False, description="Whether to use the combinatorial generator" + ) + + def invoke(self, context: InvocationContext) -> PromptCollectionOutput: + if self.combinatorial: + generator = CombinatorialPromptGenerator() + prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) + else: + generator = RandomPromptGenerator() + prompts = generator.generate(self.prompt, num_images=self.max_prompts) + + return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index aeacfe3f1c..b90b9b2f8b 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -1,5 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team -import os from abc import ABC, abstractmethod from pathlib import Path from queue import Queue @@ -76,28 +75,26 @@ class ImageFileStorageBase(ABC): class DiskImageFileStorage(ImageFileStorageBase): """Stores images on disk""" - __output_folder: str + __output_folder: Path __cache_ids: Queue # TODO: this is an incredibly naive cache - __cache: Dict[str, PILImageType] + __cache: Dict[Path, PILImageType] __max_cache_size: int - def __init__(self, output_folder: str): - self.__output_folder = output_folder + def __init__(self, output_folder: str | Path): self.__cache = dict() self.__cache_ids = Queue() self.__max_cache_size = 10 # TODO: get this from config - Path(output_folder).mkdir(parents=True, exist_ok=True) + self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) + self.__thumbnails_folder = self.__output_folder / 'thumbnails' - # TODO: don't hard-code. get/save/delete should maybe take subpath? - Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True) - Path(os.path.join(output_folder, "thumbnails")).mkdir( - parents=True, exist_ok=True - ) + # Validate required output folders at launch + self.__validate_storage_folders() def get(self, image_name: str) -> PILImageType: try: image_path = self.get_path(image_name) + cache_item = self.__get_cache(image_path) if cache_item: return cache_item @@ -116,6 +113,7 @@ class DiskImageFileStorage(ImageFileStorageBase): thumbnail_size: int = 256, ) -> None: try: + self.__validate_storage_folders() image_path = self.get_path(image_name) if metadata is not None: @@ -137,10 +135,9 @@ class DiskImageFileStorage(ImageFileStorageBase): def delete(self, image_name: str) -> None: try: - basename = os.path.basename(image_name) - image_path = self.get_path(basename) + image_path = self.get_path(image_name) - if os.path.exists(image_path): + if image_path.exists(): send2trash(image_path) if image_path in self.__cache: del self.__cache[image_path] @@ -148,7 +145,7 @@ class DiskImageFileStorage(ImageFileStorageBase): thumbnail_name = get_thumbnail_name(image_name) thumbnail_path = self.get_path(thumbnail_name, True) - if os.path.exists(thumbnail_path): + if thumbnail_path.exists(): send2trash(thumbnail_path) if thumbnail_path in self.__cache: del self.__cache[thumbnail_path] @@ -156,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase): raise ImageFileDeleteException from e # TODO: make this a bit more flexible for e.g. cloud storage - def get_path(self, image_name: str, thumbnail: bool = False) -> str: - # strip out any relative path shenanigans - basename = os.path.basename(image_name) - + def get_path(self, image_name: str, thumbnail: bool = False) -> Path: + path = self.__output_folder / image_name + if thumbnail: - thumbnail_name = get_thumbnail_name(basename) - path = os.path.join( - self.__output_folder, - "thumbnails", - thumbnail_name, - ) - else: - path = os.path.join(self.__output_folder, basename) + thumbnail_name = get_thumbnail_name(image_name) + path = self.__thumbnails_folder / thumbnail_name - abspath = os.path.abspath(path) + return path - return abspath - - def validate_path(self, path: str) -> bool: + def validate_path(self, path: str | Path) -> bool: """Validates the path given for an image or thumbnail.""" - try: - os.stat(path) - return True - except: - return False + path = path if isinstance(path, Path) else Path(path) + return path.exists() + + def __validate_storage_folders(self) -> None: + """Checks if the required output folders exist and create them if they don't""" + folders: list[Path] = [self.__output_folder, self.__thumbnails_folder] + for folder in folders: + folder.mkdir(parents=True, exist_ok=True) - def __get_cache(self, image_name: str) -> PILImageType | None: + def __get_cache(self, image_name: Path) -> PILImageType | None: return None if image_name not in self.__cache else self.__cache[image_name] - def __set_cache(self, image_name: str, image: PILImageType): + def __set_cache(self, image_name: Path, image: PILImageType): if not image_name in self.__cache: self.__cache[image_name] = image - self.__cache_ids.put( - image_name - ) # TODO: this should refresh position for LRU cache + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache if len(self.__cache) > self.__max_cache_size: cache_id = self.__cache_ids.get() if cache_id in self.__cache: diff --git a/invokeai/app/services/latent_storage.py b/invokeai/app/services/latent_storage.py index 519c254087..17d35d7c33 100644 --- a/invokeai/app/services/latent_storage.py +++ b/invokeai/app/services/latent_storage.py @@ -1,6 +1,5 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -import os from abc import ABC, abstractmethod from pathlib import Path from queue import Queue @@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase): class DiskLatentsStorage(LatentsStorageBase): """Stores latents in a folder on disk without caching""" - __output_folder: str + __output_folder: str | Path - def __init__(self, output_folder: str): - self.__output_folder = output_folder - Path(output_folder).mkdir(parents=True, exist_ok=True) + def __init__(self, output_folder: str | Path): + self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) + self.__output_folder.mkdir(parents=True, exist_ok=True) def get(self, name: str) -> torch.Tensor: latent_path = self.get_path(name) return torch.load(latent_path) def save(self, name: str, data: torch.Tensor) -> None: + self.__output_folder.mkdir(parents=True, exist_ok=True) latent_path = self.get_path(name) torch.save(data, latent_path) def delete(self, name: str) -> None: latent_path = self.get_path(name) - os.remove(latent_path) + latent_path.unlink() + - def get_path(self, name: str) -> str: - return os.path.join(self.__output_folder, name) + def get_path(self, name: str) -> Path: + return self.__output_folder / name \ No newline at end of file diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 6d1dbed780..cd83555f15 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -24,6 +24,7 @@ export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { DiffusersModelInfo } from './models/DiffusersModelInfo'; export type { DivideInvocation } from './models/DivideInvocation'; +export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation'; export type { Edge } from './models/Edge'; export type { EdgeConnection } from './models/EdgeConnection'; export type { FloatCollectionOutput } from './models/FloatCollectionOutput'; @@ -86,6 +87,7 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; +export type { PromptCollectionOutput } from './models/PromptCollectionOutput'; export type { PromptOutput } from './models/PromptOutput'; export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; diff --git a/invokeai/frontend/web/src/services/api/models/DynamicPromptInvocation.ts b/invokeai/frontend/web/src/services/api/models/DynamicPromptInvocation.ts new file mode 100644 index 0000000000..f7323a489b --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/DynamicPromptInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator + */ +export type DynamicPromptInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'dynamic_prompt'; + /** + * The prompt to parse with dynamicprompts + */ + prompt: string; + /** + * The number of prompts to generate + */ + max_prompts?: number; + /** + * Whether to use the combinatorial generator + */ + combinatorial?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts index fbc03dcf3a..efac5dabcc 100644 --- a/invokeai/frontend/web/src/services/api/models/Graph.ts +++ b/invokeai/frontend/web/src/services/api/models/Graph.ts @@ -10,6 +10,7 @@ import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleIma import type { ControlNetInvocation } from './ControlNetInvocation'; import type { CvInpaintInvocation } from './CvInpaintInvocation'; import type { DivideInvocation } from './DivideInvocation'; +import type { DynamicPromptInvocation } from './DynamicPromptInvocation'; import type { Edge } from './Edge'; import type { FloatLinearRangeInvocation } from './FloatLinearRangeInvocation'; import type { GraphInvocation } from './GraphInvocation'; @@ -71,7 +72,7 @@ export type Graph = { /** * The nodes in this graph */ - nodes?: Record; + nodes?: Record; /** * The connections between nodes and their fields in this graph */ diff --git a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts index ea41ce055b..ccd5d6f499 100644 --- a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts +++ b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts @@ -16,6 +16,7 @@ import type { IterateInvocationOutput } from './IterateInvocationOutput'; import type { LatentsOutput } from './LatentsOutput'; import type { MaskOutput } from './MaskOutput'; import type { NoiseOutput } from './NoiseOutput'; +import type { PromptCollectionOutput } from './PromptCollectionOutput'; import type { PromptOutput } from './PromptOutput'; /** @@ -45,7 +46,7 @@ export type GraphExecutionState = { /** * The results of node executions */ - results: Record; + results: Record; /** * Errors raised when executing nodes */ diff --git a/invokeai/frontend/web/src/services/api/models/PromptCollectionOutput.ts b/invokeai/frontend/web/src/services/api/models/PromptCollectionOutput.ts new file mode 100644 index 0000000000..4444ab4d33 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/PromptCollectionOutput.ts @@ -0,0 +1,19 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * Base class for invocations that output a collection of prompts + */ +export type PromptCollectionOutput = { + type: 'prompt_collection_output'; + /** + * The output prompt collection + */ + prompt_collection: Array; + /** + * The size of the prompt collection + */ + count: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts index b95f0526be..d850a1ed38 100644 --- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts +++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts @@ -9,6 +9,7 @@ import type { ContentShuffleImageProcessorInvocation } from '../models/ContentSh import type { ControlNetInvocation } from '../models/ControlNetInvocation'; import type { CvInpaintInvocation } from '../models/CvInpaintInvocation'; import type { DivideInvocation } from '../models/DivideInvocation'; +import type { DynamicPromptInvocation } from '../models/DynamicPromptInvocation'; import type { Edge } from '../models/Edge'; import type { FloatLinearRangeInvocation } from '../models/FloatLinearRangeInvocation'; import type { Graph } from '../models/Graph'; @@ -173,7 +174,7 @@ export class SessionsService { * The id of the session */ sessionId: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -210,7 +211,7 @@ export class SessionsService { * The path to the node in the graph */ nodePath: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'PUT', diff --git a/pyproject.toml b/pyproject.toml index d0ad510b0e..70a87359a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "datasets", "diffusers[torch]~=0.17.0", "dnspython==2.2.1", + "dynamicprompts", "easing-functions", "einops", "eventlet",