mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
c3f0d2e273
Unfortunately, this is necessary to prevent circular imports at runtime.
256 lines
11 KiB
Python
256 lines
11 KiB
Python
import io
|
|
from typing import Literal, Optional
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import PIL.Image
|
|
from easing_functions import (
|
|
BackEaseIn,
|
|
BackEaseInOut,
|
|
BackEaseOut,
|
|
BounceEaseIn,
|
|
BounceEaseInOut,
|
|
BounceEaseOut,
|
|
CircularEaseIn,
|
|
CircularEaseInOut,
|
|
CircularEaseOut,
|
|
CubicEaseIn,
|
|
CubicEaseInOut,
|
|
CubicEaseOut,
|
|
ElasticEaseIn,
|
|
ElasticEaseInOut,
|
|
ElasticEaseOut,
|
|
ExponentialEaseIn,
|
|
ExponentialEaseInOut,
|
|
ExponentialEaseOut,
|
|
LinearInOut,
|
|
QuadEaseIn,
|
|
QuadEaseInOut,
|
|
QuadEaseOut,
|
|
QuarticEaseIn,
|
|
QuarticEaseInOut,
|
|
QuarticEaseOut,
|
|
QuinticEaseIn,
|
|
QuinticEaseInOut,
|
|
QuinticEaseOut,
|
|
SineEaseIn,
|
|
SineEaseInOut,
|
|
SineEaseOut,
|
|
)
|
|
from matplotlib.ticker import MaxNLocator
|
|
|
|
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
|
|
|
from .baseinvocation import BaseInvocation, InvocationContext, invocation
|
|
from .fields import InputField
|
|
|
|
|
|
@invocation(
    "float_range",
    title="Float Range",
    tags=["math", "range"],
    category="math",
    version="1.0.0",
)
class FloatLinearRangeInvocation(BaseInvocation):
    """Creates a range"""

    # Endpoints of the range; both are included in the output.
    start: float = InputField(default=5, description="The first value of the range")
    stop: float = InputField(default=10, description="The last value of the range")
    # Total number of samples produced, endpoints included.
    steps: int = InputField(
        default=30,
        description="number of values to interpolate over (including start and stop)",
    )

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        # Evenly spaced samples from `start` to `stop`, `steps` of them in total.
        samples = np.linspace(self.start, self.stop, num=self.steps)
        return FloatCollectionOutput(collection=list(samples))
|
|
|
|
|
|
# Maps the user-facing easing name to the easing_functions class implementing it.
# Insertion order is significant: it is the order in which the names appear in
# EASING_FUNCTION_KEYS below. Apart from "Linear", each curve family comes in
# three variants: In, Out and InOut.
EASING_FUNCTIONS_MAP = {
    "Linear": LinearInOut,
    # polynomial families
    "QuadIn": QuadEaseIn, "QuadOut": QuadEaseOut, "QuadInOut": QuadEaseInOut,
    "CubicIn": CubicEaseIn, "CubicOut": CubicEaseOut, "CubicInOut": CubicEaseInOut,
    "QuarticIn": QuarticEaseIn, "QuarticOut": QuarticEaseOut, "QuarticInOut": QuarticEaseInOut,
    "QuinticIn": QuinticEaseIn, "QuinticOut": QuinticEaseOut, "QuinticInOut": QuinticEaseInOut,
    # trigonometric / circular / exponential families
    "SineIn": SineEaseIn, "SineOut": SineEaseOut, "SineInOut": SineEaseInOut,
    "CircularIn": CircularEaseIn, "CircularOut": CircularEaseOut, "CircularInOut": CircularEaseInOut,
    "ExponentialIn": ExponentialEaseIn, "ExponentialOut": ExponentialEaseOut, "ExponentialInOut": ExponentialEaseInOut,
    # overshooting / oscillating families
    "ElasticIn": ElasticEaseIn, "ElasticOut": ElasticEaseOut, "ElasticInOut": ElasticEaseInOut,
    "BackIn": BackEaseIn, "BackOut": BackEaseOut, "BackInOut": BackEaseInOut,
    "BounceIn": BounceEaseIn, "BounceOut": BounceEaseOut, "BounceInOut": BounceEaseInOut,
}

# Literal type whose allowed values are exactly the names in EASING_FUNCTIONS_MAP
# (iterating a dict yields its keys, in insertion order).
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP)]
|
|
|
|
|
|
# actually I think for now we could just use CollectionOutput (which is list[Any])
|
|
def _sample_easing(easing_class, start_value: float, end_value: float, count: int) -> list:
    """Sample `count` values of an easing curve going from start_value to end_value.

    The first sample is taken at step 0 and the last at step `count - 1`, which is
    also the duration handed to the easing class, so both endpoints are included.
    """
    if count <= 0:
        return []
    if count == 1:
        # A single sample would need duration=0, which the easing classes cannot
        # evaluate (the step is normalised by the duration -> division by zero).
        # The curve has not progressed at step 0, so the sample is the start value.
        return [start_value]
    easing_function = easing_class(start=start_value, end=end_value, duration=count - 1)
    return [easing_function.ease(step_index) for step_index in range(count)]


def _show_easing_plot(param_list: list, easing_name: str) -> None:
    """Render the per-step values as a bar chart and open it in the default image viewer."""
    fig = plt.figure()
    try:
        plt.xlabel("Step")
        plt.ylabel("Param Value")
        plt.title("Per-Step Values Based On Easing: " + easing_name)
        plt.bar(range(len(param_list)), param_list)
        # plt.plot(param_list)
        ax = plt.gca()
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        with io.BytesIO() as buf:
            plt.savefig(buf, format="png")
            buf.seek(0)
            # show() must run while the buffer is still open: PIL reads the data lazily.
            im = PIL.Image.open(buf)
            im.show()
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)


@invocation(
    "step_param_easing",
    title="Step Param Easing",
    tags=["step", "easing"],
    category="step",
    version="1.0.0",
)
class StepParamEasingInvocation(BaseInvocation):
    """Experimental per-step parameter easing for denoising steps"""

    # The output has one value per step, laid out as three consecutive segments:
    #   [pre-easing padding] + [eased values] + [post-easing padding]
    # whose lengths always sum to num_steps (for num_steps >= 0).

    easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
    num_steps: int = InputField(default=20, description="number of denoising steps")
    start_value: float = InputField(default=0.0, description="easing starting value")
    end_value: float = InputField(default=1.0, description="easing ending value")
    start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
    end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
    # if None, then start_value is used prior to easing start
    pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
    # if None, then end_value is used after easing end
    post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
    mirror: bool = InputField(default=False, description="include mirror of easing function")
    # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
    # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
    show_easing_plot: bool = InputField(default=False, description="show easing plot")

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        """Build the list of per-step parameter values.

        Returns a FloatCollectionOutput whose collection holds one value per step:
        entry i is the parameter value for step i.
        """
        log_diagnostics = False  # flip to True locally to trace the step bookkeeping

        # Convert the fractions to the nearest step index. The clamps are no-ops for
        # fractions within [0, 1] with start <= end; otherwise they stop the easing
        # segment from getting a negative length or running past the step range,
        # either of which would make the output length differ from num_steps.
        start_step = int(np.round(self.num_steps * self.start_step_percent))
        start_step = max(0, min(start_step, self.num_steps))
        end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
        end_step = max(start_step - 1, min(end_step, self.num_steps - 1))

        num_easing_steps = end_step - start_step + 1
        num_presteps = start_step
        num_poststeps = self.num_steps - (num_presteps + num_easing_steps)

        # Padding values fall back to the easing endpoints when not given, as the
        # field comments promise; padding with None would put None in a float list.
        # NOTE(review): with mirror=True the eased segment ends back at start_value,
        # yet the documented post-easing fallback is still end_value - confirm intent.
        pre_value = self.start_value if self.pre_start_value is None else self.pre_start_value
        post_value = self.end_value if self.post_end_value is None else self.post_end_value
        prelist = num_presteps * [pre_value]
        postlist = num_poststeps * [post_value]

        if log_diagnostics:
            context.services.logger.debug("start_step: " + str(start_step))
            context.services.logger.debug("end_step: " + str(end_step))
            context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
            context.services.logger.debug("num_presteps: " + str(num_presteps))
            context.services.logger.debug("num_poststeps: " + str(num_poststeps))
            context.services.logger.debug("prelist size: " + str(len(prelist)))
            context.services.logger.debug("postlist size: " + str(len(postlist)))
            context.services.logger.debug("prelist: " + str(prelist))
            context.services.logger.debug("postlist: " + str(postlist))

        easing_class = EASING_FUNCTIONS_MAP[self.easing]
        if log_diagnostics:
            context.services.logger.debug("easing class: " + str(easing_class))

        if self.mirror:  # "expected" mirroring
            # Ease over the first half of the segment, then append the reversed half.
            # Even step count: the halves are equal, so append the full reverse.
            # Odd step count: the peak value is the middle step and must appear only
            # once, so drop it from the reversed copy.
            # ceil(n / 2) is the base length in both cases.
            base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
            if log_diagnostics:
                context.services.logger.debug("base easing duration: " + str(base_easing_duration))
            base_easing_vals = _sample_easing(
                easing_class, self.start_value, self.end_value, base_easing_duration
            )
            if num_easing_steps % 2 == 0:
                mirror_easing_vals = base_easing_vals[::-1]
            else:
                mirror_easing_vals = base_easing_vals[:-1][::-1]
            if log_diagnostics:
                context.services.logger.debug("base easing vals: " + str(base_easing_vals))
                context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
            easing_list = base_easing_vals + mirror_easing_vals

        # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
        # elif self.alt_mirror:  # function mirroring (unintuitive behavior (at least to me))
        #     The sketched idea: half_ease_duration = round((num_easing_steps - 1) / 2);
        #     ease start_value -> end_value over the first half_ease_duration steps, then
        #     use a second easing function of the same class going end_value -> start_value
        #     for the remaining steps (evaluated at step_index - half_ease_duration).

        else:  # no mirroring (default)
            easing_list = _sample_easing(easing_class, self.start_value, self.end_value, num_easing_steps)

        if log_diagnostics:
            for step_index, easing_val in enumerate(easing_list):
                context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
            context.services.logger.debug("prelist size: " + str(len(prelist)))
            context.services.logger.debug("easing_list size: " + str(len(easing_list)))
            context.services.logger.debug("postlist size: " + str(len(postlist)))

        param_list = prelist + easing_list + postlist

        if self.show_easing_plot:
            _show_easing_plot(param_list, self.easing)

        # output array of size steps, each entry list[i] is param value for step i
        return FloatCollectionOutput(collection=param_list)
|