mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Feat/easy param (#3504)
* Testing change to LatentsToText to allow setting different cfg_scale values per diffusion step. * Adding first attempt at float param easing node, using Penner easing functions. * Core implementation of ControlNet and MultiControlNet. * Added support for ControlNet and MultiControlNet to legacy non-nodal Txt2Img in backend/generator. Although backend/generator will likely disappear by v3.x, right now they are very useful for testing core ControlNet and MultiControlNet functionality while node codebase is rapidly evolving. * Added example of using ControlNet with legacy Txt2Img generator * Resolving rebase conflict * Added first controlnet preprocessor node for canny edge detection. * Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node * Switching to ControlField for output from controlnet nodes. * Resolving conflicts in rebase to origin/main * Refactored ControlNet nodes so they subclass from PreprocessedControlInvocation, and only need to override run_processor(image) (instead of reimplementing invoke()) * changes to base class for controlnet nodes * Added HED, LineArt, and OpenPose ControlNet nodes * Added an additional "raw_processed_image" output port to controlnets, mainly so could route ImageField to a ShowImage node * Added more preprocessor nodes for: MidasDepth ZoeDepth MLSD NormalBae Pidi LineartAnime ContentShuffle Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup. * Prep for splitting pre-processor and controlnet nodes * Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. * Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. * More rebase repair. * Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this... * Fixed use of ControlNet control_weight parameter * Fixed lint-ish formatting error * Core implementation of ControlNet and MultiControlNet. * Added first controlnet preprocessor node for canny edge detection. * Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node * Switching to ControlField for output from controlnet nodes. * Refactored controlnet node to output ControlField that bundles control info. * changes to base class for controlnet nodes * Added more preprocessor nodes for: MidasDepth ZoeDepth MLSD NormalBae Pidi LineartAnime ContentShuffle Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup. * Prep for splitting pre-processor and controlnet nodes * Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. * Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. * Cleaning up TextToLatent arg testing * Cleaning up mistakes after rebase. * Removed last bits of dtype and and device hardwiring from controlnet section * Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled. * Added support for specifying which step iteration to start using each ControlNet, and which step to end using each controlnet (specified as fraction of total steps) * Cleaning up prior to submitting ControlNet PR. Mostly turning off diagnostic printing. Also fixed error when there is no controlnet input. * Added dependency on controlnet-aux v0.0.3 * Commented out ZoeDetector. Will re-instate once there's a controlnet-aux release that supports it. * Switched CotrolNet node modelname input from free text to default list of popular ControlNet model names. * Fix to work with current stable release of controlnet_aux (v0.0.3). Turned of pre-processor params that were added post v0.0.3. Also change defaults for shuffle. * Refactored most of controlnet code into its own method to declutter TextToLatents.invoke(), and make upcoming integration with LatentsToLatents easier. * Cleaning up after ControlNet refactor in TextToLatentsInvocation * Extended node-based ControlNet support to LatentsToLatentsInvocation. * chore(ui): regen api client * fix(ui): add value to conditioning field * fix(ui): add control field type * fix(ui): fix node ui type hints * fix(nodes): controlnet input accepts list or single controlnet * Moved to controlnet_aux v0.0.4, reinstated Zoe controlnet preprocessor. Also in pyproject.toml had to specify downgrade of timm to 0.6.13 _after_ controlnet-aux installs timm >= 0.9.2, because timm >0.6.13 breaks Zoe preprocessor. * Core implementation of ControlNet and MultiControlNet. * Added first controlnet preprocessor node for canny edge detection. * Switching to ControlField for output from controlnet nodes. * Resolving conflicts in rebase to origin/main * Refactored ControlNet nodes so they subclass from PreprocessedControlInvocation, and only need to override run_processor(image) (instead of reimplementing invoke()) * changes to base class for controlnet nodes * Added HED, LineArt, and OpenPose ControlNet nodes * Added more preprocessor nodes for: MidasDepth ZoeDepth MLSD NormalBae Pidi LineartAnime ContentShuffle Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup. * Prep for splitting pre-processor and controlnet nodes * Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. * Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. * Added support for using multiple control nets. Unfortunately this breaks direct usage of Control node output port ==> TextToLatent control input port -- passing through a Collect node is now required. Working on fixing this... * Fixed use of ControlNet control_weight parameter * Core implementation of ControlNet and MultiControlNet. * Added first controlnet preprocessor node for canny edge detection. * Initial port of controlnet node support from generator-based TextToImageInvocation node to latent-based TextToLatentsInvocation node * Switching to ControlField for output from controlnet nodes. * Refactored controlnet node to output ControlField that bundles control info. * changes to base class for controlnet nodes * Added more preprocessor nodes for: MidasDepth ZoeDepth MLSD NormalBae Pidi LineartAnime ContentShuffle Removed pil_output options, ControlNet preprocessors should always output as PIL. Removed diagnostics and other general cleanup. * Prep for splitting pre-processor and controlnet nodes * Refactored controlnet nodes: split out controlnet stuff into separate node, stripped controlnet stuff form image processing/analysis nodes. * Added resizing of controlnet image based on noise latent. Fixes a tensor mismatch issue. * Cleaning up TextToLatent arg testing * Cleaning up mistakes after rebase. * Removed last bits of dtype and and device hardwiring from controlnet section * Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled. * Added support for specifying which step iteration to start using each ControlNet, and which step to end using each controlnet (specified as fraction of total steps) * Cleaning up prior to submitting ControlNet PR. Mostly turning off diagnostic printing. Also fixed error when there is no controlnet input. * Commented out ZoeDetector. Will re-instate once there's a controlnet-aux release that supports it. * Switched CotrolNet node modelname input from free text to default list of popular ControlNet model names. * Fix to work with current stable release of controlnet_aux (v0.0.3). Turned of pre-processor params that were added post v0.0.3. Also change defaults for shuffle. * Refactored most of controlnet code into its own method to declutter TextToLatents.invoke(), and make upcoming integration with LatentsToLatents easier. * Cleaning up after ControlNet refactor in TextToLatentsInvocation * Extended node-based ControlNet support to LatentsToLatentsInvocation. * chore(ui): regen api client * fix(ui): fix node ui type hints * fix(nodes): controlnet input accepts list or single controlnet * Added Mediapipe image processor for use as ControlNet preprocessor. Also hacked in ability to specify HF subfolder when loading ControlNet models from string. * Fixed bug where MediapipFaceProcessorInvocation was ignoring max_faces and min_confidence params. * Added nodes for float params: ParamFloatInvocation and FloatCollectionOutput. Also added FloatOutput. * Added mediapipe install requirement. Should be able to remove once controlnet_aux package adds mediapipe to its requirements. * Added float to FIELD_TYPE_MAP ins constants.ts * Progress toward improvement in fieldTemplateBuilder.ts getFieldType() * Fixed controlnet preprocessors and controlnet handling in TextToLatents to work with revised Image services. * Cleaning up from merge, re-adding cfg_scale to FIELD_TYPE_MAP * Making sure cfg_scale of type list[float] can be used in image metadata, to support param easing for cfg_scale * Fixed math for per-step param easing. * Added option to show plot of param value at each step * Just cleaning up after adding param easing plot option, removing vestigial code. * Modified control_weight ControlNet param to be polistmorphic -- can now be either a single float weight applied for all steps, or a list of floats of size total_steps, that specifies weight for each step. * Added more informative error message when _validat_edge() throws an error. * Just improving parm easing bar chart title to include easing type. * Added requirement for easing-functions package * Taking out some diagnostic prints. * Added option to use both easing function and mirror of easing function together. * Fixed recently introduced problem (when pulled in main), triggered by num_steps in StepParamEasingInvocation not having a default value -- just added default. --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
30f20b55d5
commit
c647056287
@ -1,11 +1,12 @@
|
||||
# InvokeAI nodes for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import float
|
||||
|
||||
import numpy as np
|
||||
from typing import Literal, Optional, Union, List
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
@ -14,6 +15,7 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
HEDdetector,
|
||||
@ -96,15 +98,32 @@ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||
control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
|
||||
@validator("control_weight")
|
||||
def abs_le_one(cls, v):
|
||||
"""validate that all abs(values) are <=1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if abs(i) > 1:
|
||||
raise ValueError('all abs(control_weight) must be <= 1')
|
||||
else:
|
||||
if abs(v) > 1:
|
||||
raise ValueError('abs(control_weight) must be <= 1')
|
||||
return v
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -112,7 +131,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The output control image")
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@ -123,15 +142,28 @@ class ControlNetInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="The ControlNet model to use")
|
||||
control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet")
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
|
||||
@ -161,7 +193,6 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
raw_image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
@ -174,22 +174,36 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -244,10 +258,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
self.cfg_scale,
|
||||
extra_conditioning_info,
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
@ -348,7 +362,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
@ -385,6 +400,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
"cfg_scale": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -403,10 +419,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
print("type of control input: ", type(self.control))
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
|
237
invokeai/app/invocations/param_easing.py
Normal file
237
invokeai/app/invocations/param_easing.py
Normal file
@ -0,0 +1,237 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
|
||||
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[
|
||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||
|
||||
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||
num_easing_steps = end_step - start_step + 1
|
||||
|
||||
# num_presteps = max(start_step - 1, 0)
|
||||
num_presteps = start_step
|
||||
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||
prelist = list(num_presteps * [self.pre_start_value])
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
logger = InvokeAILogger.getLogger(name="StepParamEasing")
|
||||
logger.debug("start_step: " + str(start_step))
|
||||
logger.debug("end_step: " + str(end_step))
|
||||
logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
logger.debug("num_presteps: " + str(num_presteps))
|
||||
logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
logger.debug("prelist: " + str(prelist))
|
||||
logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = list()
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||
# easing_function = easing_class(start=self.start_value,
|
||||
# end=self.end_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
#
|
||||
# mirror_function = easing_class(start=self.end_value,
|
||||
# end=self.start_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
# for step_index in range(num_easing_steps):
|
||||
# if step_index <= half_ease_duration:
|
||||
# step_val = easing_function.ease(step_index)
|
||||
# else:
|
||||
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||
# easing_list.append(step_val)
|
||||
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
if self.show_easing_plot:
|
||||
plt.figure()
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Param Value")
|
||||
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||
plt.bar(range(len(param_list)), param_list)
|
||||
# plt.plot(param_list)
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, List
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
@ -47,7 +47,9 @@ class ImageMetadata(BaseModel):
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
"""The seed used for noise generation"""
|
||||
cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Union[float, list[float]] = Field(
|
||||
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
"""The classifier-free guidance scale"""
|
||||
|
@ -65,7 +65,6 @@ from typing import Optional, Union, List, get_args
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
|
||||
if not t1_args:
|
||||
# t1 is a single type
|
||||
return t1 in t2_args
|
||||
@ -86,7 +85,6 @@ def is_list_or_contains_list(t):
|
||||
for arg in t_args:
|
||||
if get_origin(arg) is list:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@ -393,7 +391,7 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
@ -404,41 +402,41 @@ class Graph(BaseModel):
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
|
@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
class ControlNetData:
|
||||
model: ControlNetModel = Field(default=None)
|
||||
image_tensor: torch.Tensor= Field(default=None)
|
||||
weight: float = Field(default=1.0)
|
||||
weight: Union[float, List[float]]= Field(default=1.0)
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
@ -226,7 +226,7 @@ class ControlNetData:
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: torch.Tensor
|
||||
text_embeddings: torch.Tensor
|
||||
guidance_scale: float
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
if control_data is not None:
|
||||
if conditioning_data.guidance_scale > 1.0:
|
||||
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
|
||||
# if conditioning_data.guidance_scale > 1.0:
|
||||
if conditioning_data.guidance_scale is not None:
|
||||
# expand the latents input to control model if doing classifier free guidance
|
||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||
# classifier_free_guidance is <= 1.0 ?)
|
||||
@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||
if step_index >= first_control_step and step_index <= last_control_step:
|
||||
# print("running controlnet", i, "for step", step_index)
|
||||
if isinstance(control_datum.weight, list):
|
||||
# if controlnet has multiple weights, use the weight for the current step
|
||||
controlnet_weight = control_datum.weight[step_index]
|
||||
else:
|
||||
# if controlnet has a single weight, use it for all steps
|
||||
controlnet_weight = control_datum.weight
|
||||
down_samples, mid_sample = control_datum.model(
|
||||
sample=latent_control_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings]),
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=control_datum.weight,
|
||||
conditioning_scale=controlnet_weight,
|
||||
# cross_attention_kwargs,
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent:
|
||||
sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor, dict],
|
||||
conditioning: Union[torch.Tensor, dict],
|
||||
unconditional_guidance_scale: float,
|
||||
# unconditional_guidance_scale: float,
|
||||
unconditional_guidance_scale: Union[float, List[float]],
|
||||
step_index: Optional[int] = None,
|
||||
total_step_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent:
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
if isinstance(unconditional_guidance_scale, list):
|
||||
guidance_scale = unconditional_guidance_scale[step_index]
|
||||
else:
|
||||
guidance_scale = unconditional_guidance_scale
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
combined_next_x = self._combine(
|
||||
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||
unconditioned_next_x, conditioned_next_x, guidance_scale
|
||||
)
|
||||
|
||||
return combined_next_x
|
||||
@ -497,7 +504,7 @@ class InvokeAIDiffuserComponent:
|
||||
logger.debug(
|
||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||
)
|
||||
logger.debug(
|
||||
logger.debug(
|
||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
)
|
||||
|
||||
|
@ -18,6 +18,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
ColorField: 'color',
|
||||
ControlField: 'control',
|
||||
control: 'control',
|
||||
cfg_scale: 'float',
|
||||
control_weight: 'float',
|
||||
};
|
||||
|
||||
const COLOR_TOKEN_VALUE = 500;
|
||||
|
@ -0,0 +1,33 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Applies HED edge detection to image
|
||||
*/
|
||||
export type HedImageprocessorInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
type?: 'hed_image_processor';
|
||||
/**
|
||||
* image to process
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* pixel resolution for edge detection
|
||||
*/
|
||||
detect_resolution?: number;
|
||||
/**
|
||||
* pixel resolution for output image
|
||||
*/
|
||||
image_resolution?: number;
|
||||
/**
|
||||
* whether to use scribble mode
|
||||
*/
|
||||
scribble?: boolean;
|
||||
};
|
||||
|
@ -44,6 +44,7 @@ dependencies = [
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.16.1",
|
||||
"dnspython==2.2.1",
|
||||
"easing-functions",
|
||||
"einops",
|
||||
"eventlet",
|
||||
"facexlib",
|
||||
@ -56,6 +57,7 @@ dependencies = [
|
||||
"flaskwebgui==1.0.3",
|
||||
"gfpgan==1.3.8",
|
||||
"huggingface-hub>=0.11.1",
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||
"npyscreen",
|
||||
"numpy<1.24",
|
||||
|
Loading…
Reference in New Issue
Block a user