Merge branch 'main' into lstein/new-model-manager

This commit is contained in:
Lincoln Stein 2023-06-14 08:41:40 -07:00 committed by GitHub
commit 82c2498043
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 136 additions and 55 deletions

View File

@ -2,8 +2,8 @@ from typing import Literal
from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
'prompt',
]
}
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@ -76,28 +75,26 @@ class ImageFileStorageBase(ABC):
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__cache: Dict[Path, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
def __init__(self, output_folder: str | Path):
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# TODO: don't hard-code. get/save/delete should maybe take subpath?
Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True)
Path(os.path.join(output_folder, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
# Validate required output folders at launch
self.__validate_storage_folders()
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@ -116,6 +113,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_size: int = 256,
) -> None:
try:
self.__validate_storage_folders()
image_path = self.get_path(image_name)
if metadata is not None:
@ -137,10 +135,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
def delete(self, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(basename)
image_path = self.get_path(image_name)
if os.path.exists(image_path):
if image_path.exists():
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
@ -148,7 +145,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path):
if thumbnail_path.exists():
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
@ -156,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder,
"thumbnails",
thumbnail_name,
)
else:
path = os.path.join(self.__output_folder, basename)
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
abspath = os.path.abspath(path)
return path
return abspath
def validate_path(self, path: str) -> bool:
def validate_path(self, path: str | Path) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: str) -> PILImageType | None:
def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
__output_folder: str | Path
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
latent_path.unlink()
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
def get_path(self, name: str) -> Path:
return self.__output_folder / name

View File

@ -24,6 +24,7 @@ export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation';
export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation';
export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection';
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
@ -86,6 +87,7 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';

View File

@ -0,0 +1,31 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator
*/
export type DynamicPromptInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'dynamic_prompt';
/**
* The prompt to parse with dynamicprompts
*/
prompt: string;
/**
* The number of prompts to generate
*/
max_prompts?: number;
/**
* Whether to use the combinatorial generator
*/
combinatorial?: boolean;
};

View File

@ -10,6 +10,7 @@ import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleIma
import type { ControlNetInvocation } from './ControlNetInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation';
import type { DynamicPromptInvocation } from './DynamicPromptInvocation';
import type { Edge } from './Edge';
import type { FloatLinearRangeInvocation } from './FloatLinearRangeInvocation';
import type { GraphInvocation } from './GraphInvocation';
@ -71,7 +72,7 @@ export type Graph = {
/**
* The nodes in this graph
*/
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/**
* The connections between nodes and their fields in this graph
*/

View File

@ -16,6 +16,7 @@ import type { IterateInvocationOutput } from './IterateInvocationOutput';
import type { LatentsOutput } from './LatentsOutput';
import type { MaskOutput } from './MaskOutput';
import type { NoiseOutput } from './NoiseOutput';
import type { PromptCollectionOutput } from './PromptCollectionOutput';
import type { PromptOutput } from './PromptOutput';
/**
@ -45,7 +46,7 @@ export type GraphExecutionState = {
/**
* The results of node executions
*/
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | PromptCollectionOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
/**
* Errors raised when executing nodes
*/

View File

@ -0,0 +1,19 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Base class for invocations that output a collection of prompts
*/
export type PromptCollectionOutput = {
type: 'prompt_collection_output';
/**
* The output prompt collection
*/
prompt_collection: Array<string>;
/**
* The size of the prompt collection
*/
count: number;
};

View File

@ -9,6 +9,7 @@ import type { ContentShuffleImageProcessorInvocation } from '../models/ContentSh
import type { ControlNetInvocation } from '../models/ControlNetInvocation';
import type { CvInpaintInvocation } from '../models/CvInpaintInvocation';
import type { DivideInvocation } from '../models/DivideInvocation';
import type { DynamicPromptInvocation } from '../models/DynamicPromptInvocation';
import type { Edge } from '../models/Edge';
import type { FloatLinearRangeInvocation } from '../models/FloatLinearRangeInvocation';
import type { Graph } from '../models/Graph';
@ -173,7 +174,7 @@ export class SessionsService {
* The id of the session
*/
sessionId: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<string> {
return __request(OpenAPI, {
method: 'POST',
@ -210,7 +211,7 @@ export class SessionsService {
* The path to the node in the graph
*/
nodePath: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'PUT',

View File

@ -44,6 +44,7 @@ dependencies = [
"datasets",
"diffusers[torch]~=0.17.0",
"dnspython==2.2.1",
"dynamicprompts",
"easing-functions",
"einops",
"eventlet",