Mirror of https://github.com/invoke-ai/InvokeAI
c238a7f18b
Upgrade pydantic and fastapi to latest.

- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1

**Big Changes**

There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.

**Invocations**

The biggest change relates to invocation creation, instantiation and validation. Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie. Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`. With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation. This logic has been totally refactored and we no longer need to change the model. The details are in `baseinvocation.py`, in the `InputField` function and the `BaseInvocation.invoke_internal()` method. In the end, this implementation is cleaner.

**Invocation Fields**

In pydantic v2, you can no longer directly add or remove fields from a model. Previously, we did this to add the `type` field to invocations.

**Invocation Decorators**

With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py`, in the `invocation()` wrapper. A similar technique is used for `invocation_output()`.

**Minor Changes**

There are a number of minor changes around the pydantic v2 models API.

**Protected `model_` Namespace**

All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_". Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.

```py
class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")

    model_config = ConfigDict(protected_namespaces=())
```

**Model Serialization**

Pydantic models no longer have `Model.dict()` or `Model.json()`. Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.

**Model Deserialization**

Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions. Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.

```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```

**Field Customisation**

Pydantic `Field`s no longer accept arbitrary args. Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field, as sketched below.
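A minimal sketch of that migration (the `ExampleField` model and the `ui_hidden` extra are illustrative, not from this commit):

```py
from pydantic import BaseModel, Field


class ExampleField(BaseModel):
    # pydantic v1 accepted arbitrary kwargs: Field(description="Width", ui_hidden=True)
    # pydantic v2 requires extras to go through json_schema_extra:
    width: int = Field(description="Width", json_schema_extra={"ui_hidden": True})
```

The extras still land under the field's entry in `ExampleField.model_json_schema()`, so schema consumers can read them as before.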
**Schema Customisation**

FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec. This necessitates two changes:

- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised

The specifics aren't important, but this does present additional surface area for bugs.

**Performance Improvements**

Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very, very often - a couple of times per node. I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like those with massive iterators - are much, much faster.
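A quick, illustrative way to gauge the serialization side locally (the `Node` model here is hypothetical, not InvokeAI's graph model):

```py
import timeit

from pydantic import BaseModel


class Node(BaseModel):
    id: str
    type: str
    values: list[float]


node = Node(id="1", type="float_range", values=[0.1] * 100)

# Time 10k serializations on the rust-backed v2 serializer; run the same loop
# with Model.json() under pydantic v1 to compare on your own workload.
elapsed = timeit.timeit(node.model_dump_json, number=10_000)
print(f"10k model_dump_json calls: {elapsed:.3f}s")
```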
255 lines
11 KiB
Python
import io
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
from easing_functions import (
    BackEaseIn,
    BackEaseInOut,
    BackEaseOut,
    BounceEaseIn,
    BounceEaseInOut,
    BounceEaseOut,
    CircularEaseIn,
    CircularEaseInOut,
    CircularEaseOut,
    CubicEaseIn,
    CubicEaseInOut,
    CubicEaseOut,
    ElasticEaseIn,
    ElasticEaseInOut,
    ElasticEaseOut,
    ExponentialEaseIn,
    ExponentialEaseInOut,
    ExponentialEaseOut,
    LinearInOut,
    QuadEaseIn,
    QuadEaseInOut,
    QuadEaseOut,
    QuarticEaseIn,
    QuarticEaseInOut,
    QuarticEaseOut,
    QuinticEaseIn,
    QuinticEaseInOut,
    QuinticEaseOut,
    SineEaseIn,
    SineEaseInOut,
    SineEaseOut,
)
from matplotlib.ticker import MaxNLocator

from invokeai.app.invocations.primitives import FloatCollectionOutput

from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation

@invocation(
    "float_range",
    title="Float Range",
    tags=["math", "range"],
    category="math",
    version="1.0.0",
)
class FloatLinearRangeInvocation(BaseInvocation):
    """Creates a range of evenly spaced float values from start to stop"""

    start: float = InputField(default=5, description="The first value of the range")
    stop: float = InputField(default=10, description="The last value of the range")
    steps: int = InputField(
        default=30,
        description="number of values to interpolate over (including start and stop)",
    )

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        param_list = list(np.linspace(self.start, self.stop, self.steps))
        return FloatCollectionOutput(collection=param_list)

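# Mapping of user-facing easing names to their easing_functions classes.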
EASING_FUNCTIONS_MAP = {
    "Linear": LinearInOut,
    "QuadIn": QuadEaseIn,
    "QuadOut": QuadEaseOut,
    "QuadInOut": QuadEaseInOut,
    "CubicIn": CubicEaseIn,
    "CubicOut": CubicEaseOut,
    "CubicInOut": CubicEaseInOut,
    "QuarticIn": QuarticEaseIn,
    "QuarticOut": QuarticEaseOut,
    "QuarticInOut": QuarticEaseInOut,
    "QuinticIn": QuinticEaseIn,
    "QuinticOut": QuinticEaseOut,
    "QuinticInOut": QuinticEaseInOut,
    "SineIn": SineEaseIn,
    "SineOut": SineEaseOut,
    "SineInOut": SineEaseInOut,
    "CircularIn": CircularEaseIn,
    "CircularOut": CircularEaseOut,
    "CircularInOut": CircularEaseInOut,
    "ExponentialIn": ExponentialEaseIn,
    "ExponentialOut": ExponentialEaseOut,
    "ExponentialInOut": ExponentialEaseInOut,
    "ElasticIn": ElasticEaseIn,
    "ElasticOut": ElasticEaseOut,
    "ElasticInOut": ElasticEaseInOut,
    "BackIn": BackEaseIn,
    "BackOut": BackEaseOut,
    "BackInOut": BackEaseInOut,
    "BounceIn": BounceEaseIn,
    "BounceOut": BounceEaseOut,
    "BounceInOut": BounceEaseInOut,
}

EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]

# actually I think for now we could just use CollectionOutput (which is list[Any])
@invocation(
    "step_param_easing",
    title="Step Param Easing",
    tags=["step", "easing"],
    category="step",
    version="1.0.0",
)
class StepParamEasingInvocation(BaseInvocation):
    """Experimental per-step parameter easing for denoising steps"""

    easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
    num_steps: int = InputField(default=20, description="number of denoising steps")
    start_value: float = InputField(default=0.0, description="easing starting value")
    end_value: float = InputField(default=1.0, description="easing ending value")
    start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
    end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
    # if None, then start_value is used prior to easing start
    pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
    # if None, then end_value is used after easing end
    post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
    mirror: bool = InputField(default=False, description="include mirror of easing function")
    # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
    # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
    show_easing_plot: bool = InputField(default=False, description="show easing plot")

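    # Builds the per-step parameter list in three segments: a constant pre-easing
    # segment (pre_start_value), the eased values, and a constant post-easing
    # segment (post_end_value).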
    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        log_diagnostics = False
        # convert from start_step_percent to nearest step <= (steps * start_step_percent)
        # start_step = int(np.floor(self.num_steps * self.start_step_percent))
        start_step = int(np.round(self.num_steps * self.start_step_percent))
        # convert from end_step_percent to nearest step >= (steps * end_step_percent)
        # end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
        end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))

        # end_step = int(np.ceil(self.num_steps * self.end_step_percent))
        num_easing_steps = end_step - start_step + 1

        # num_presteps = max(start_step - 1, 0)
        num_presteps = start_step
        num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
        prelist = list(num_presteps * [self.pre_start_value])
        postlist = list(num_poststeps * [self.post_end_value])

        if log_diagnostics:
            context.services.logger.debug("start_step: " + str(start_step))
            context.services.logger.debug("end_step: " + str(end_step))
            context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
            context.services.logger.debug("num_presteps: " + str(num_presteps))
            context.services.logger.debug("num_poststeps: " + str(num_poststeps))
            context.services.logger.debug("prelist size: " + str(len(prelist)))
            context.services.logger.debug("postlist size: " + str(len(postlist)))
            context.services.logger.debug("prelist: " + str(prelist))
            context.services.logger.debug("postlist: " + str(postlist))

        easing_class = EASING_FUNCTIONS_MAP[self.easing]
        if log_diagnostics:
            context.services.logger.debug("easing class: " + str(easing_class))
        easing_list = list()
        if self.mirror:  # "expected" mirroring
            # if number of steps is even, squeeze duration down to (number_of_steps)/2
            # and create reverse copy of list to append
            # if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
            # and create reverse copy of list[1:end-1]
            # but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always

            base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
            if log_diagnostics:
                context.services.logger.debug("base easing duration: " + str(base_easing_duration))
            even_num_steps = num_easing_steps % 2 == 0  # even number of steps
            easing_function = easing_class(
                start=self.start_value,
                end=self.end_value,
                duration=base_easing_duration - 1,
            )
            base_easing_vals = list()
            for step_index in range(base_easing_duration):
                easing_val = easing_function.ease(step_index)
                base_easing_vals.append(easing_val)
                if log_diagnostics:
                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
            if even_num_steps:
                mirror_easing_vals = list(reversed(base_easing_vals))
            else:
                mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
            if log_diagnostics:
                context.services.logger.debug("base easing vals: " + str(base_easing_vals))
                context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
            easing_list = base_easing_vals + mirror_easing_vals

        # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
        # elif self.alt_mirror:  # function mirroring (unintuitive behavior (at least to me))
        #     # half_ease_duration = round(num_easing_steps - 1 / 2)
        #     half_ease_duration = round((num_easing_steps - 1) / 2)
        #     easing_function = easing_class(start=self.start_value,
        #                                    end=self.end_value,
        #                                    duration=half_ease_duration,
        #                                    )
        #
        #     mirror_function = easing_class(start=self.end_value,
        #                                    end=self.start_value,
        #                                    duration=half_ease_duration,
        #                                    )
        #     for step_index in range(num_easing_steps):
        #         if step_index <= half_ease_duration:
        #             step_val = easing_function.ease(step_index)
        #         else:
        #             step_val = mirror_function.ease(step_index - half_ease_duration)
        #         easing_list.append(step_val)
        #         if log_diagnostics: logger.debug(step_index, step_val)
        #

        else:  # no mirroring (default)
            easing_function = easing_class(
                start=self.start_value,
                end=self.end_value,
                duration=num_easing_steps - 1,
            )
            for step_index in range(num_easing_steps):
                step_val = easing_function.ease(step_index)
                easing_list.append(step_val)
                if log_diagnostics:
                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))

        if log_diagnostics:
            context.services.logger.debug("prelist size: " + str(len(prelist)))
            context.services.logger.debug("easing_list size: " + str(len(easing_list)))
            context.services.logger.debug("postlist size: " + str(len(postlist)))

        param_list = prelist + easing_list + postlist

        if self.show_easing_plot:
            plt.figure()
            plt.xlabel("Step")
            plt.ylabel("Param Value")
            plt.title("Per-Step Values Based On Easing: " + self.easing)
            plt.bar(range(len(param_list)), param_list)
            # plt.plot(param_list)
            ax = plt.gca()
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            buf = io.BytesIO()
            plt.savefig(buf, format="png")
            buf.seek(0)
            im = PIL.Image.open(buf)
            im.show()
            buf.close()

        # output array of size steps, each entry list[i] is param value for step i
        return FloatCollectionOutput(collection=param_list)