diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index baf558ac24..8cfe35598d 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -1,10 +1,11 @@
-# InvokeAI nodes for ControlNet image preprocessors
+# Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import float, bool

+import cv2
 import numpy as np
-from typing import Literal, Optional, Union, List
+from typing import Literal, Optional, Union, List, Dict

 from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field, validator
@@ -29,8 +30,13 @@ from controlnet_aux import (
     ContentShuffleDetector,
     ZoeDetector,
     MediapipeFaceDetector,
+    SamDetector,
+    LeresDetector,
 )

+from controlnet_aux.util import HWC3, ade_palette
+
+
 from .image import ImageOutput, PILInvocationConfig

 CONTROLNET_DEFAULT_MODELS = [
@@ -95,6 +101,9 @@ CONTROLNET_DEFAULT_MODELS = [
 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
 CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]

+# crop and fill options not ready yet
+# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
+

 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
@@ -105,7 +114,8 @@ class ControlField(BaseModel):
                                      description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The contorl mode to use")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

     @validator("control_weight")
     def abs_le_one(cls, v):
@@ -180,7 +190,7 @@ class ControlNetInvocation(BaseInvocation):
             ),
         )

-# TODO: move image processors to separate file (image_analysis.py
+
 class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
     """Base class for invocations that preprocess images for ControlNet"""

@@ -452,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
     # fmt: on

     def run_processor(self, image):
+        # MediaPipeFaceDetector throws an error if image has alpha channel
+        # so convert to RGB if needed
+        if image.mode == 'RGBA':
+            image = image.convert('RGB')
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
         return processed_image
+
+
+class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies leres processing to image"""
+    # fmt: off
+    type: Literal["leres_image_processor"] = "leres_image_processor"
+    # Inputs
+    thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
+    thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
+    boost: bool = Field(default=False, description="Whether to use boost mode")
+    detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
+    image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
+    # fmt: on
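+
+    # Assumed semantics, inherited from the sd-webui-controlnet "depth_leres"
+    # processor that controlnet_aux ports: thr_a trims near depth and thr_b
+    # trims background depth (each a 0-100 percentage, 0 = off), and
+    # boost=True runs the slower, higher-detail boosted depth estimation.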
+
+    def run_processor(self, image):
+        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = leres_processor(image,
+                                          thr_a=self.thr_a,
+                                          thr_b=self.thr_b,
+                                          boost=self.boost,
+                                          detect_resolution=self.detect_resolution,
+                                          image_resolution=self.image_resolution)
+        return processed_image
+
+
+class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Tile resampler processor"""
+    # fmt: off
+    type: Literal["tile_image_processor"] = "tile_image_processor"
+    # Inputs
+    #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
+    down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
+    # fmt: on
+
+    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
+    def tile_resample(self,
+                      np_img: np.ndarray,
+                      res=512,  # never used?
+                      down_sampling_rate=1.0,
+                      ):
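+        # Implementation notes (assumptions based on the upstream
+        # sd-webui-controlnet code this is copied from): HWC3() from
+        # controlnet_aux.util normalizes input to a 3-channel uint8 array;
+        # rates below 1.1 are treated as "no downsampling"; INTER_AREA
+        # averages source pixels, avoiding aliasing when shrinking.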
+        np_img = HWC3(np_img)
+        if down_sampling_rate < 1.1:
+            return np_img
+        H, W, C = np_img.shape
+        H = int(float(H) / float(down_sampling_rate))
+        W = int(float(W) / float(down_sampling_rate))
+        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
+        return np_img
+
+    def run_processor(self, img):
+        np_img = np.array(img, dtype=np.uint8)
+        processed_np_image = self.tile_resample(np_img,
+                                                #res=self.tile_size,
+                                                down_sampling_rate=self.down_sampling_rate
+                                                )
+        processed_image = Image.fromarray(processed_np_image)
+        return processed_image
+
+
+
+class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies segment anything processing to image"""
+    # fmt: off
+    type: Literal["segment_anything_processor"] = "segment_anything_processor"
+    # fmt: on
+
+    def run_processor(self, image):
+        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+        np_img = np.array(image, dtype=np.uint8)
+        processed_image = segment_anything_processor(np_img)
+        return processed_image
+
+
+class SamDetectorReproducibleColors(SamDetector):
+
+    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
+    # base class show_anns() method randomizes colors,
+    # which seems to also lead to non-reproducible image generation
+    # so using ADE20k color palette instead
+    def show_anns(self, anns: List[Dict]):
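+        # anns is the output of SAM's automatic mask generator: each dict
+        # holds (among other keys) a boolean 'segmentation' mask and its
+        # pixel 'area', which the loop below relies on.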
+        if len(anns) == 0:
+            return
+        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+        h, w = anns[0]['segmentation'].shape
+        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
+        palette = ade_palette()
+        for i, ann in enumerate(sorted_anns):
+            m = ann['segmentation']
+            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
+            # doing modulo just in case number of annotated regions exceeds number of colors in palette
+            ann_color = palette[i % len(palette)]
+            img[:, :] = ann_color
+            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
+        return np.array(final_img, dtype=np.uint8)
diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py
index 1ff6261b88..e79763a35e 100644
--- a/invokeai/app/invocations/param_easing.py
+++ b/invokeai/app/invocations/param_easing.py
@@ -133,20 +133,19 @@ class StepParamEasingInvocation(BaseInvocation):
         postlist = list(num_poststeps * [self.post_end_value])

         if log_diagnostics:
-            logger = InvokeAILogger.getLogger(name="StepParamEasing")
-            logger.debug("start_step: " + str(start_step))
-            logger.debug("end_step: " + str(end_step))
-            logger.debug("num_easing_steps: " + str(num_easing_steps))
-            logger.debug("num_presteps: " + str(num_presteps))
-            logger.debug("num_poststeps: " + str(num_poststeps))
-            logger.debug("prelist size: " + str(len(prelist)))
-            logger.debug("postlist size: " + str(len(postlist)))
-            logger.debug("prelist: " + str(prelist))
-            logger.debug("postlist: " + str(postlist))
+            context.services.logger.debug("start_step: " + str(start_step))
+            context.services.logger.debug("end_step: " + str(end_step))
+            context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
+            context.services.logger.debug("num_presteps: " + str(num_presteps))
+            context.services.logger.debug("num_poststeps: " + str(num_poststeps))
+            context.services.logger.debug("prelist size: " + str(len(prelist)))
+            context.services.logger.debug("postlist size: " + str(len(postlist)))
+            context.services.logger.debug("prelist: " + str(prelist))
+            context.services.logger.debug("postlist: " + str(postlist))

         easing_class = EASING_FUNCTIONS_MAP[self.easing]
         if log_diagnostics:
-            logger.debug("easing class: " + str(easing_class))
+            context.services.logger.debug("easing class: " + str(easing_class))
         easing_list = list()
         if self.mirror:  # "expected" mirroring
             # if number of steps is even, squeeze duration down to (number_of_steps)/2
@@ -156,7 +155,7 @@ class StepParamEasingInvocation(BaseInvocation):
             #  but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
             base_easing_duration = int(np.ceil(num_easing_steps/2.0))
-            if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
+            if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
             even_num_steps = (num_easing_steps % 2 == 0)  # even number of steps
             easing_function = easing_class(start=self.start_value,
                                            end=self.end_value,
@@ -166,14 +165,14 @@
                 easing_val = easing_function.ease(step_index)
                 base_easing_vals.append(easing_val)
                 if log_diagnostics:
-                    logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
+                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
             if even_num_steps:
                 mirror_easing_vals = list(reversed(base_easing_vals))
             else:
                 mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
             if log_diagnostics:
-                logger.debug("base easing vals: " + str(base_easing_vals))
-                logger.debug("mirror easing vals: " + str(mirror_easing_vals))
+                context.services.logger.debug("base easing vals: " + str(base_easing_vals))
+                context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
             easing_list = base_easing_vals + mirror_easing_vals

         # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@@ -206,12 +205,12 @@
                 step_val = easing_function.ease(step_index)
                 easing_list.append(step_val)
                 if log_diagnostics:
-                    logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
+                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))

         if log_diagnostics:
-            logger.debug("prelist size: " + str(len(prelist)))
-            logger.debug("easing_list size: " + str(len(easing_list)))
-            logger.debug("postlist size: " + str(len(postlist)))
+            context.services.logger.debug("prelist size: " + str(len(prelist)))
+            context.services.logger.debug("easing_list size: " + str(len(easing_list)))
+            context.services.logger.debug("postlist size: " + str(len(postlist)))

         param_list = prelist + easing_list + postlist
diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx
index eeaf1b81ec..1aefecf3e6 100644
--- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx
+++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx
@@ -35,8 +35,8 @@ const ParamDynamicPromptsCollapse = () => {
       withSwitch
     >
-      <…>
+      <…>
   );
diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx
index ab56abaa35..19f02ae3e5 100644
--- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx
+++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx
@@ -9,17 +9,18 @@ import { stateSelector } from 'app/store/store';

 const selector = createSelector(
   stateSelector,
   (state) => {
-    const { maxPrompts } = state.dynamicPrompts;
+    const { maxPrompts, combinatorial } = state.dynamicPrompts;
     const { min, sliderMax, inputMax } =
       state.config.sd.dynamicPrompts.maxPrompts;
-    return { maxPrompts, min, sliderMax, inputMax };
+    return { maxPrompts, min, sliderMax, inputMax, combinatorial };
   },
   defaultSelectorOptions
 );

 const ParamDynamicPromptsMaxPrompts = () => {
-  const { maxPrompts, min, sliderMax, inputMax } = useAppSelector(selector);
+  const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
+    useAppSelector(selector);
   const dispatch = useAppDispatch();

   const handleChange = useCallback(
@@ -36,6 +37,7 @@ const ParamDynamicPromptsMaxPrompts = () => {

   return (
     <IAISlider
+      isDisabled={!combinatorial}
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx
@@ … @@
   (state) => {
     const { … } =
      state.config.sd.iterations;
     const { iterations } = state.generation;
     const { shouldUseSliders } = state.ui;
-    const isDisabled = state.dynamicPrompts.isEnabled;
+    const isDisabled =
+      state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;

     const step = state.hotkeys.shift ? fineStep : coarseStep;
diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts
index a995d9c298..2a2f90f434 100644
--- a/invokeai/frontend/web/src/services/api/types.d.ts
+++ b/invokeai/frontend/web/src/services/api/types.d.ts
@@ -4,91 +4,89 @@ import { components } from './schema';

 type schemas = components['schemas'];

 /**
- * Helper type to extract the invocation type from the schema.
- * Also flags the `type` property as required.
+ * Extracts the schema type from the schema.
  */
-type Invocation<T extends keyof schemas> = O.Required<schemas[T], 'type'>;
+type S<T extends keyof components['schemas']> = components['schemas'][T];

 /**
- * Types from the API, re-exported from the types generated by `openapi-typescript`.
+ * Extracts the node type from the schema.
+ * Also flags the `type` property as required.
  */
+type N<T extends keyof components['schemas']> = O.Required<
+  components['schemas'][T],
+  'type'
+>;
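+// For example, `N<'NoiseInvocation'>` is the generated NoiseInvocation schema
+// type with its `type` discriminator made required (`O.Required` comes from
+// ts-toolbelt).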

 // Images
-export type ImageDTO = schemas['ImageDTO'];
-export type BoardDTO = schemas['BoardDTO'];
-export type BoardChanges = schemas['BoardChanges'];
-export type ImageChanges = schemas['ImageRecordChanges'];
-export type ImageCategory = schemas['ImageCategory'];
-export type ResourceOrigin = schemas['ResourceOrigin'];
-export type ImageField = schemas['ImageField'];
+export type ImageDTO = S<'ImageDTO'>;
+export type BoardDTO = S<'BoardDTO'>;
+export type BoardChanges = S<'BoardChanges'>;
+export type ImageChanges = S<'ImageRecordChanges'>;
+export type ImageCategory = S<'ImageCategory'>;
+export type ResourceOrigin = S<'ResourceOrigin'>;
+export type ImageField = S<'ImageField'>;
 export type OffsetPaginatedResults_BoardDTO_ =
-  schemas['OffsetPaginatedResults_BoardDTO_'];
+  S<'OffsetPaginatedResults_BoardDTO_'>;
 export type OffsetPaginatedResults_ImageDTO_ =
-  schemas['OffsetPaginatedResults_ImageDTO_'];
+  S<'OffsetPaginatedResults_ImageDTO_'>;

 // Models
-export type ModelType = schemas['ModelType'];
-export type BaseModelType = schemas['BaseModelType'];
-export type PipelineModelField = schemas['PipelineModelField'];
-export type ModelsList = schemas['ModelsList'];
+export type ModelType = S<'ModelType'>;
+export type BaseModelType = S<'BaseModelType'>;
+export type PipelineModelField = S<'PipelineModelField'>;
+export type ModelsList = S<'ModelsList'>;

 // Graphs
-export type Graph = schemas['Graph'];
-export type Edge = schemas['Edge'];
-export type GraphExecutionState = schemas['GraphExecutionState'];
+export type Graph = S<'Graph'>;
+export type Edge = S<'Edge'>;
+export type GraphExecutionState = S<'GraphExecutionState'>;

 // General nodes
-export type CollectInvocation = Invocation<'CollectInvocation'>;
-export type IterateInvocation = Invocation<'IterateInvocation'>;
-export type RangeInvocation = Invocation<'RangeInvocation'>;
-export type RandomRangeInvocation = Invocation<'RandomRangeInvocation'>;
-export type RangeOfSizeInvocation = Invocation<'RangeOfSizeInvocation'>;
-export type InpaintInvocation = Invocation<'InpaintInvocation'>;
-export type ImageResizeInvocation = Invocation<'ImageResizeInvocation'>;
-export type RandomIntInvocation = Invocation<'RandomIntInvocation'>;
-export type CompelInvocation = Invocation<'CompelInvocation'>;
-export type DynamicPromptInvocation = Invocation<'DynamicPromptInvocation'>;
-export type NoiseInvocation = Invocation<'NoiseInvocation'>;
-export type TextToLatentsInvocation = Invocation<'TextToLatentsInvocation'>;
-export type LatentsToLatentsInvocation =
-  Invocation<'LatentsToLatentsInvocation'>;
-export type ImageToLatentsInvocation = Invocation<'ImageToLatentsInvocation'>;
-export type LatentsToImageInvocation = Invocation<'LatentsToImageInvocation'>;
-export type PipelineModelLoaderInvocation =
-  Invocation<'PipelineModelLoaderInvocation'>;
+export type CollectInvocation = N<'CollectInvocation'>;
+export type IterateInvocation = N<'IterateInvocation'>;
+export type RangeInvocation = N<'RangeInvocation'>;
+export type RandomRangeInvocation = N<'RandomRangeInvocation'>;
+export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>;
+export type InpaintInvocation = N<'InpaintInvocation'>;
+export type ImageResizeInvocation = N<'ImageResizeInvocation'>;
+export type RandomIntInvocation = N<'RandomIntInvocation'>;
+export type CompelInvocation = N<'CompelInvocation'>;
+export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>;
+export type NoiseInvocation = N<'NoiseInvocation'>;
+export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
+export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
+export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
+export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
+export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;

 // ControlNet Nodes
-export type ControlNetInvocation = Invocation<'ControlNetInvocation'>;
-export type CannyImageProcessorInvocation =
-  Invocation<'CannyImageProcessorInvocation'>;
+export type ControlNetInvocation = N<'ControlNetInvocation'>;
+export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>;
 export type ContentShuffleImageProcessorInvocation =
-  Invocation<'ContentShuffleImageProcessorInvocation'>;
-export type HedImageProcessorInvocation =
-  Invocation<'HedImageProcessorInvocation'>;
+  N<'ContentShuffleImageProcessorInvocation'>;
+export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>;
 export type LineartAnimeImageProcessorInvocation =
-  Invocation<'LineartAnimeImageProcessorInvocation'>;
+  N<'LineartAnimeImageProcessorInvocation'>;
 export type LineartImageProcessorInvocation =
-  Invocation<'LineartImageProcessorInvocation'>;
+  N<'LineartImageProcessorInvocation'>;
 export type MediapipeFaceProcessorInvocation =
-  Invocation<'MediapipeFaceProcessorInvocation'>;
+  N<'MediapipeFaceProcessorInvocation'>;
 export type MidasDepthImageProcessorInvocation =
-  Invocation<'MidasDepthImageProcessorInvocation'>;
-export type MlsdImageProcessorInvocation =
-  Invocation<'MlsdImageProcessorInvocation'>;
+  N<'MidasDepthImageProcessorInvocation'>;
+export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>;
 export type NormalbaeImageProcessorInvocation =
-  Invocation<'NormalbaeImageProcessorInvocation'>;
+  N<'NormalbaeImageProcessorInvocation'>;
 export type OpenposeImageProcessorInvocation =
-  Invocation<'OpenposeImageProcessorInvocation'>;
-export type PidiImageProcessorInvocation =
-  Invocation<'PidiImageProcessorInvocation'>;
+  N<'OpenposeImageProcessorInvocation'>;
+export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>;
 export type ZoeDepthImageProcessorInvocation =
-  Invocation<'ZoeDepthImageProcessorInvocation'>;
+  N<'ZoeDepthImageProcessorInvocation'>;

 // Node Outputs
-export type ImageOutput = schemas['ImageOutput'];
-export type MaskOutput = schemas['MaskOutput'];
-export type PromptOutput = schemas['PromptOutput'];
-export type IterateInvocationOutput = schemas['IterateInvocationOutput'];
-export type CollectInvocationOutput = schemas['CollectInvocationOutput'];
-export type LatentsOutput = schemas['LatentsOutput'];
-export type GraphInvocationOutput = schemas['GraphInvocationOutput'];
+export type ImageOutput = S<'ImageOutput'>;
+export type MaskOutput = S<'MaskOutput'>;
+export type PromptOutput = S<'PromptOutput'>;
+export type IterateInvocationOutput = S<'IterateInvocationOutput'>;
+export type CollectInvocationOutput = S<'CollectInvocationOutput'>;
+export type LatentsOutput = S<'LatentsOutput'>;
+export type GraphInvocationOutput = S<'GraphInvocationOutput'>;
diff --git a/pyproject.toml b/pyproject.toml
index 03396312ac..6e5b8f4e22 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
   "click",
   "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel>=1.2.1",
-  "controlnet-aux>=0.0.4",
+  "controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "datasets", "diffusers[torch]~=0.17.1",