import io
from typing import Literal, Optional, Any

# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure

from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt

from easing_functions import (
    LinearInOut,
    QuadEaseInOut,
    QuadEaseIn,
    QuadEaseOut,
    CubicEaseInOut,
    CubicEaseIn,
    CubicEaseOut,
    QuarticEaseInOut,
    QuarticEaseIn,
    QuarticEaseOut,
    QuinticEaseInOut,
    QuinticEaseIn,
    QuinticEaseOut,
    SineEaseInOut,
    SineEaseIn,
    SineEaseOut,
    CircularEaseIn,
    CircularEaseInOut,
    CircularEaseOut,
    ExponentialEaseInOut,
    ExponentialEaseIn,
    ExponentialEaseOut,
    ElasticEaseIn,
    ElasticEaseInOut,
    ElasticEaseOut,
    BackEaseIn,
    BackEaseInOut,
    BackEaseOut,
    BounceEaseIn,
    BounceEaseInOut,
    BounceEaseOut,
)

from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
    InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput


class FloatLinearRangeInvocation(BaseInvocation):
    """Creates a range"""

    # Produces `steps` evenly spaced floats from `start` to `stop`, both endpoints included
    # (np.linspace semantics).

    type: Literal["float_range"] = "float_range"

    # Inputs
    start: float = Field(default=5, description="The first value of the range")
    stop: float = Field(default=10, description="The last value of the range")
    steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
        }

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        # .tolist() yields plain Python floats instead of numpy float64 scalars.
        param_list = np.linspace(self.start, self.stop, self.steps).tolist()
        return FloatCollectionOutput(collection=param_list)


# Maps the user-facing easing names to their easing_functions classes.
EASING_FUNCTIONS_MAP = {
    "Linear": LinearInOut,
    "QuadIn": QuadEaseIn,
    "QuadOut": QuadEaseOut,
    "QuadInOut": QuadEaseInOut,
    "CubicIn": CubicEaseIn,
    "CubicOut": CubicEaseOut,
    "CubicInOut": CubicEaseInOut,
    "QuarticIn": QuarticEaseIn,
    "QuarticOut": QuarticEaseOut,
    "QuarticInOut": QuarticEaseInOut,
    "QuinticIn": QuinticEaseIn,
    "QuinticOut": QuinticEaseOut,
    "QuinticInOut": QuinticEaseInOut,
    "SineIn": SineEaseIn,
    "SineOut": SineEaseOut,
    "SineInOut": SineEaseInOut,
    "CircularIn": CircularEaseIn,
    "CircularOut": CircularEaseOut,
    "CircularInOut": CircularEaseInOut,
    "ExponentialIn": ExponentialEaseIn,
    "ExponentialOut": ExponentialEaseOut,
    "ExponentialInOut": ExponentialEaseInOut,
    "ElasticIn": ElasticEaseIn,
    "ElasticOut": ElasticEaseOut,
    "ElasticInOut": ElasticEaseInOut,
    "BackIn": BackEaseIn,
    "BackOut": BackEaseOut,
    "BackInOut": BackEaseInOut,
    "BounceIn": BounceEaseIn,
    "BounceOut": BounceEaseOut,
    "BounceInOut": BounceEaseInOut,
}

# Literal type built dynamically from the map keys, so the schema only accepts known easing names.
# Annotated as Any because static type checkers cannot evaluate a dynamically built Literal.
EASING_FUNCTION_KEYS: Any = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]


def _ease_values(easing_class, start_value: float, end_value: float, num_values: int) -> list:
    """Sample `num_values` points of `easing_class` from start_value to end_value (both inclusive).

    Returns [] for num_values <= 0 and [start_value] for a single sample. A one-sample curve
    would have duration 0, which makes easing_functions divide by zero.
    """
    if num_values <= 0:
        return []
    if num_values == 1:
        return [start_value]
    easing_function = easing_class(start=start_value, end=end_value, duration=num_values - 1)
    return [easing_function.ease(step_index) for step_index in range(num_values)]


def _show_param_plot(param_list: list, easing_name: str) -> None:
    """Render the per-step values as a bar chart and open it in the local image viewer."""
    # An explicit Figure (not pyplot) keeps the figure out of pyplot's global registry, so it
    # is garbage-collected instead of accumulating across invocations.
    fig = Figure()
    ax = fig.subplots()
    ax.set_xlabel("Step")
    ax.set_ylabel("Param Value")
    ax.set_title("Per-Step Values Based On Easing: " + easing_name)
    ax.bar(range(len(param_list)), param_list)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png")
        buf.seek(0)
        with PIL.Image.open(buf) as im:
            im.show()


# NOTE(review): the original author remarked that the generic CollectionOutput (list[Any])
# could be used here instead of FloatCollectionOutput.
class StepParamEasingInvocation(BaseInvocation):
    """Experimental per-step parameter easing for denoising steps"""

    type: Literal["step_param_easing"] = "step_param_easing"

    # Inputs
    # fmt: off
    easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
    num_steps: int = Field(default=20, description="number of denoising steps")
    start_value: float = Field(default=0.0, description="easing starting value")
    end_value: float = Field(default=1.0, description="easing ending value")
    start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
    end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
    # if None, then start_value is used prior to easing start
    pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
    # if None, then end_value is used after easing end
    post_end_value: Optional[float] = Field(default=None, description="value after easing end")
    mirror: bool = Field(default=False, description="include mirror of easing function")
    # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
    # alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
    show_easing_plot: bool = Field(default=False, description="show easing plot")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
        }

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        """Build one param value per denoising step: pre-values, the eased segment, then post-values."""
        log_diagnostics = False
        logger = context.services.logger

        # Convert the step fractions to step indices, rounding to the nearest step.
        start_step = int(np.round(self.num_steps * self.start_step_percent))
        end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
        # Clamp at 0: if the end step rounds to before the start step there is simply no eased
        # segment. A negative count would over-count post-steps and make the output longer than
        # num_steps.
        num_easing_steps = max(end_step - start_step + 1, 0)
        num_presteps = start_step
        num_poststeps = max(self.num_steps - (num_presteps + num_easing_steps), 0)

        # Unset pre/post values fall back to the eased segment's start/end values, as documented
        # on the fields.
        # NOTE(review): in mirror mode the eased segment returns to start_value, so the
        # end_value fallback for post-steps produces a jump -- confirm this is intended.
        pre_start_value = self.start_value if self.pre_start_value is None else self.pre_start_value
        post_end_value = self.end_value if self.post_end_value is None else self.post_end_value
        prelist = num_presteps * [pre_start_value]
        postlist = num_poststeps * [post_end_value]

        easing_class = EASING_FUNCTIONS_MAP[self.easing]
        if log_diagnostics:
            logger.debug("start_step: " + str(start_step))
            logger.debug("end_step: " + str(end_step))
            logger.debug("num_easing_steps: " + str(num_easing_steps))
            logger.debug("num_presteps: " + str(num_presteps))
            logger.debug("num_poststeps: " + str(num_poststeps))
            logger.debug("easing class: " + str(easing_class))

        if self.mirror:
            # Ease over the first ceil(n/2) steps, then append that curve reversed.
            # For an odd step count the peak sample is shared, so it is not repeated in the mirror.
            base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
            base_easing_vals = _ease_values(easing_class, self.start_value, self.end_value, base_easing_duration)
            if num_easing_steps % 2 == 0:
                mirror_easing_vals = list(reversed(base_easing_vals))
            else:
                mirror_easing_vals = list(reversed(base_easing_vals[:-1]))
            if log_diagnostics:
                logger.debug("base easing duration: " + str(base_easing_duration))
                logger.debug("base easing vals: " + str(base_easing_vals))
                logger.debug("mirror easing vals: " + str(mirror_easing_vals))
            easing_list = base_easing_vals + mirror_easing_vals
        else:
            # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely.
            # The prototype eased start->end over the first half and end->start (a second easing
            # function) over the second half.
            easing_list = _ease_values(easing_class, self.start_value, self.end_value, num_easing_steps)

        if log_diagnostics:
            logger.debug("prelist size: " + str(len(prelist)))
            logger.debug("easing_list size: " + str(len(easing_list)))
            logger.debug("postlist size: " + str(len(postlist)))

        # Entry i is the param value for denoising step i.
        param_list = prelist + easing_list + postlist

        if self.show_easing_plot:
            _show_param_plot(param_list, self.easing)

        return FloatCollectionOutput(collection=param_list)