mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
238 lines
10 KiB
Python
238 lines
10 KiB
Python
|
import io
|
||
|
from typing import Literal, Optional, Any
|
||
|
|
||
|
# from PIL.Image import Image
|
||
|
import PIL.Image
|
||
|
from matplotlib.ticker import MaxNLocator
|
||
|
from matplotlib.figure import Figure
|
||
|
|
||
|
from pydantic import BaseModel, Field
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
from easing_functions import (
|
||
|
LinearInOut,
|
||
|
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||
|
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||
|
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||
|
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||
|
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||
|
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||
|
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||
|
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||
|
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||
|
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||
|
|
||
|
from .baseinvocation import (
|
||
|
BaseInvocation,
|
||
|
BaseInvocationOutput,
|
||
|
InvocationContext,
|
||
|
InvocationConfig,
|
||
|
)
|
||
|
from ...backend.util.logging import InvokeAILogger
|
||
|
from .collections import FloatCollectionOutput
|
||
|
|
||
|
|
||
|
class FloatLinearRangeInvocation(BaseInvocation):
    """Creates a range of evenly spaced floats.

    Emits ``steps`` values linearly interpolated from ``start`` to ``stop``,
    with both endpoints included (computed with ``numpy.linspace``).
    """

    type: Literal["float_range"] = "float_range"

    # Inputs
    start: float = Field(default=5, description="The first value of the range")
    stop: float = Field(default=10, description="The last value of the range")
    steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        """Return the interpolated values wrapped in a FloatCollectionOutput."""
        samples = np.linspace(self.start, self.stop, self.steps)
        values = list(samples)
        return FloatCollectionOutput(collection=values)
|
||
|
|
||
|
|
||
|
# Maps the user-facing easing name to its easing_functions class. The
# insertion order here is also the order of the Literal options built from it.
EASING_FUNCTIONS_MAP = dict(
    Linear=LinearInOut,
    QuadIn=QuadEaseIn,
    QuadOut=QuadEaseOut,
    QuadInOut=QuadEaseInOut,
    CubicIn=CubicEaseIn,
    CubicOut=CubicEaseOut,
    CubicInOut=CubicEaseInOut,
    QuarticIn=QuarticEaseIn,
    QuarticOut=QuarticEaseOut,
    QuarticInOut=QuarticEaseInOut,
    QuinticIn=QuinticEaseIn,
    QuinticOut=QuinticEaseOut,
    QuinticInOut=QuinticEaseInOut,
    SineIn=SineEaseIn,
    SineOut=SineEaseOut,
    SineInOut=SineEaseInOut,
    CircularIn=CircularEaseIn,
    CircularOut=CircularEaseOut,
    CircularInOut=CircularEaseInOut,
    ExponentialIn=ExponentialEaseIn,
    ExponentialOut=ExponentialEaseOut,
    ExponentialInOut=ExponentialEaseInOut,
    ElasticIn=ElasticEaseIn,
    ElasticOut=ElasticEaseOut,
    ElasticInOut=ElasticEaseInOut,
    BackIn=BackEaseIn,
    BackOut=BackEaseOut,
    BackInOut=BackEaseInOut,
    BounceIn=BounceEaseIn,
    BounceOut=BounceEaseOut,
    BounceInOut=BounceEaseInOut,
)
|
||
|
|
||
|
# Literal type whose options are every name in EASING_FUNCTIONS_MAP (in
# insertion order); used as the type of the `easing` input below. Subscripting
# Literal with a tuple is equivalent to passing each name as its own argument.
EASING_FUNCTION_KEYS: Any = Literal[tuple(EASING_FUNCTIONS_MAP)]
|
||
|
|
||
|
|
||
|
# actually I think for now could just use CollectionOutput (which is list[Any])
class StepParamEasingInvocation(BaseInvocation):
    """Experimental per-step parameter easing for denoising steps

    Produces ``num_steps`` floats laid out as pre-steps + easing steps +
    post-steps. Pre-steps hold ``pre_start_value`` (``start_value`` if unset).
    The easing steps follow the selected curve from ``start_value`` to
    ``end_value``; with ``mirror`` the curve is followed back down. Post-steps
    hold ``post_end_value`` (``end_value`` if unset).
    """

    type: Literal["step_param_easing"] = "step_param_easing"

    # Inputs
    # fmt: off
    easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
    num_steps: int = Field(default=20, description="number of denoising steps")
    start_value: float = Field(default=0.0, description="easing starting value")
    end_value: float = Field(default=1.0, description="easing ending value")
    start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
    end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
    # if None, then start_value is used prior to easing start
    pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
    # if None, then end_value is used after easing end
    post_end_value: Optional[float] = Field(default=None, description="value after easing end")
    mirror: bool = Field(default=False, description="include mirror of easing function")
    # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely.
    # Idea: ease start->end over the first half of the window, then end->start over the second half.
    # alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
    show_easing_plot: bool = Field(default=False, description="show easing plot")
    # fmt: on

    @staticmethod
    def _ease_values(easing_class: Any, start: float, end: float, count: int) -> list:
        """Sample `count` points of `easing_class` from `start` (first) to `end` (last)."""
        if count <= 0:
            return []
        if count == 1:
            # easing_functions divides by `duration`; a zero-length curve would
            # raise ZeroDivisionError, so a single step just takes the start value.
            return [start]
        easing_function = easing_class(start=start, end=end, duration=count - 1)
        return [easing_function.ease(step_index) for step_index in range(count)]

    def _show_easing_plot(self, param_list: list) -> None:
        """Render `param_list` as a bar chart and open it with PIL's image viewer."""
        fig = plt.figure()
        try:
            ax = fig.gca()
            ax.set_xlabel("Step")
            ax.set_ylabel("Param Value")
            ax.set_title("Per-Step Values Based On Easing: " + self.easing)
            ax.bar(range(len(param_list)), param_list)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            with io.BytesIO() as buf:
                fig.savefig(buf, format="png")
                buf.seek(0)
                with PIL.Image.open(buf) as im:
                    im.show()
        finally:
            # pyplot keeps every figure alive until it is closed; without this,
            # each plotted invocation would leak a figure.
            plt.close(fig)

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        """Return one value per step: collection[i] is the param value for step i."""
        log_diagnostics = False  # set True to log the computed step layout

        # Easing window start: step nearest to num_steps * start_step_percent.
        start_step = int(np.round(self.num_steps * self.start_step_percent))
        # Easing window end (inclusive): index nearest to (num_steps - 1) * end_step_percent.
        end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
        # Clamp at 0: an inverted window (end before start) used to give a negative
        # count, which inflated num_poststeps and made the output too long.
        num_easing_steps = max(end_step - start_step + 1, 0)
        num_presteps = start_step
        num_poststeps = self.num_steps - (num_presteps + num_easing_steps)

        # Unset pre/post values fall back to the easing endpoints (see field
        # comments); previously None itself was emitted for those steps.
        pre_value = self.start_value if self.pre_start_value is None else self.pre_start_value
        post_value = self.end_value if self.post_end_value is None else self.post_end_value
        prelist = num_presteps * [pre_value]
        postlist = num_poststeps * [post_value]

        easing_class = EASING_FUNCTIONS_MAP[self.easing]
        if self.mirror:  # "expected" mirroring
            # Ease over the first ceil(n / 2) steps, then append the same values
            # reversed. For an odd n the peak value is not repeated, so the total
            # stays n (for even n, n / 2 == ceil(n / 2)).
            base_duration = int(np.ceil(num_easing_steps / 2.0))
            base_easing_vals = self._ease_values(easing_class, self.start_value, self.end_value, base_duration)
            if num_easing_steps % 2 == 0:
                mirror_easing_vals = list(reversed(base_easing_vals))
            else:
                mirror_easing_vals = list(reversed(base_easing_vals[:-1]))
            easing_list = base_easing_vals + mirror_easing_vals
        else:  # no mirroring (default)
            easing_list = self._ease_values(easing_class, self.start_value, self.end_value, num_easing_steps)

        param_list = prelist + easing_list + postlist

        if log_diagnostics:
            logger = InvokeAILogger.getLogger(name="StepParamEasing")
            logger.debug(f"start_step: {start_step}, end_step: {end_step}")
            logger.debug(f"num_presteps: {num_presteps}, num_easing_steps: {num_easing_steps}, num_poststeps: {num_poststeps}")
            logger.debug(f"easing class: {easing_class}")
            logger.debug(f"param_list: {param_list}")

        if self.show_easing_plot:
            self._show_easing_plot(param_list)

        return FloatCollectionOutput(collection=param_list)
|