diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 15aecde851..b32afe4941 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -1,11 +1,12 @@ # InvokeAI nodes for ControlNet image preprocessors # initial implementation by Gregg Helt, 2023 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux +from builtins import float import numpy as np from typing import Literal, Optional, Union, List from PIL import Image, ImageFilter, ImageOps -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from ..models.image import ImageField, ImageCategory, ResourceOrigin from .baseinvocation import ( @@ -14,6 +15,7 @@ from .baseinvocation import ( InvocationContext, InvocationConfig, ) + from controlnet_aux import ( CannyDetector, HEDdetector, @@ -96,15 +98,32 @@ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] class ControlField(BaseModel): image: ImageField = Field(default=None, description="The control image") control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") - control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet") + # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") + control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field(default=0, ge=0, le=1, - description="When the ControlNet is first applied (% of total steps)") + description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)") - + @validator("control_weight") + def abs_le_one(cls, v): + """validate that all abs(control_weight) values are <= 1""" + if isinstance(v, list): + for i in v: + if abs(i) > 1: + raise ValueError('all abs(control_weight) must be <= 1') + else: + if abs(v) > 1: + raise ValueError('abs(control_weight) must be <= 1') + return v class Config: schema_extra = { - "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"] + "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"], + "ui": { + "type_hints": { + "control_weight": "float", + # "control_weight": "number", + } + } } @@ -112,7 +131,7 @@ class ControlOutput(BaseInvocationOutput): """node output for ControlNet info""" # fmt: off type: Literal["control_output"] = "control_output" - control: ControlField = Field(default=None, description="The control info") + control: ControlField = Field(default=None, description="The control info") # fmt: on @@ -123,15 +142,28 @@ class ControlNetInvocation(BaseInvocation): # Inputs image: ImageField = Field(default=None, description="The control image") control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", - description="The ControlNet model to use") - control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet") + description="The ControlNet model to use") + control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode begin_step_percent: float = Field(default=0,
ge=0, le=1, - description="When the ControlNet is first applied (% of total steps)") + description="When the ControlNet is first applied (% of total steps)") end_step_percent: float = Field(default=1, ge=0, le=1, - description="When the ControlNet is last applied (% of total steps)") + description="When the ControlNet is last applied (% of total steps)") # fmt: on + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents"], + "type_hints": { + "model": "model", + "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number", + "control_weight": "float", + } + }, + } def invoke(self, context: InvocationContext) -> ControlOutput: @@ -161,7 +193,6 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): return image def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.services.images.get_pil_image( self.image.image_origin, self.image.image_name ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 814f52f86f..dbd419b6e5 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -174,22 +174,36 @@ class TextToLatentsInvocation(BaseInvocation): negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") noise: Optional[LatentsField] = Field(description="The noise to use") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") - cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) + cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) model: str = Field(default="", description="The model to use (currently ignored)") - control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") + control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use") # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # fmt: on + @validator("cfg_scale") + def ge_one(cls, v): + """validate that all cfg_scale values are >= 1""" + if isinstance(v, list): + for i in v: + if i < 1: + raise ValueError('cfg_scale must be >= 1') + else: + if v < 1: + raise ValueError('cfg_scale must be >= 1') + return v + # Schema customisation class Config(InvocationConfig): schema_extra = { "ui": { - "tags": ["latents", "image"], + "tags": ["latents"], "type_hints": { "model": "model", "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number" } }, } @@ -244,10 +258,10 @@ class TextToLatentsInvocation(BaseInvocation): [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) conditioning_data = ConditioningData( - uc, - c, - self.cfg_scale, - extra_conditioning_info, + unconditioned_embeddings=uc, + text_embeddings=c, + guidance_scale=self.cfg_scale, + extra=extra_conditioning_info, postprocessing_settings=PostprocessingSettings( threshold=0.0,#threshold, warmup=0.2,#warmup, @@ -348,7 +362,8 @@ class TextToLatentsInvocation(BaseInvocation): control_data = self.prep_control_data(model=model, context=context, 
control_input=self.control, latents_shape=noise.shape, - do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + do_classifier_free_guidance=True) # TODO: Verify the noise is the right size result_latents, result_attention_map_saver = model.latents_from_embeddings( @@ -385,6 +400,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): "type_hints": { "model": "model", "control": "control", + "cfg_scale": "number", } }, } @@ -403,10 +419,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): model = self.get_model(context.services.model_manager) conditioning_data = self.get_conditioning_data(context, model) - print("type of control input: ", type(self.control)) control_data = self.prep_control_data(model=model, context=context, control_input=self.control, latents_shape=noise.shape, - do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + do_classifier_free_guidance=True, + ) # TODO: Verify the noise is the right size diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py new file mode 100644 index 0000000000..1ff6261b88 --- /dev/null +++ b/invokeai/app/invocations/param_easing.py @@ -0,0 +1,237 @@ +import io +from typing import Literal, Optional, Any + +# from PIL.Image import Image +import PIL.Image +from matplotlib.ticker import MaxNLocator +from matplotlib.figure import Figure + +from pydantic import BaseModel, Field +import numpy as np +import matplotlib.pyplot as plt + +from easing_functions import ( + LinearInOut, + QuadEaseInOut, QuadEaseIn, QuadEaseOut, + CubicEaseInOut, CubicEaseIn, CubicEaseOut, + QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut, + QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut, + SineEaseInOut, SineEaseIn, SineEaseOut, + CircularEaseIn, CircularEaseInOut, CircularEaseOut, + ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut, + ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut, + BackEaseIn, BackEaseInOut, BackEaseOut, + BounceEaseIn, BounceEaseInOut, BounceEaseOut) + +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationContext, + InvocationConfig, +) +from ...backend.util.logging import InvokeAILogger +from .collections import FloatCollectionOutput + + +class FloatLinearRangeInvocation(BaseInvocation): + """Creates a linear range of floats, from start to stop inclusive""" + + type: Literal["float_range"] = "float_range" + + # Inputs + start: float = Field(default=5, description="The first value of the range") + stop: float = Field(default=10, description="The last value of the range") + steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)") + + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + param_list = list(np.linspace(self.start, self.stop, self.steps)) + return FloatCollectionOutput( + collection=param_list + ) + + +EASING_FUNCTIONS_MAP = { + "Linear": LinearInOut, + "QuadIn": QuadEaseIn, + "QuadOut": QuadEaseOut, + "QuadInOut": QuadEaseInOut, + "CubicIn": CubicEaseIn, + "CubicOut": CubicEaseOut, + "CubicInOut": CubicEaseInOut, + "QuarticIn": QuarticEaseIn, + "QuarticOut": QuarticEaseOut, + "QuarticInOut": QuarticEaseInOut, + "QuinticIn": QuinticEaseIn, + "QuinticOut": QuinticEaseOut, + "QuinticInOut": QuinticEaseInOut, + "SineIn": SineEaseIn, + "SineOut": SineEaseOut, + "SineInOut": SineEaseInOut, + "CircularIn": CircularEaseIn, + "CircularOut": CircularEaseOut, + "CircularInOut": CircularEaseInOut, + 
"ExponentialIn": ExponentialEaseIn, + "ExponentialOut": ExponentialEaseOut, + "ExponentialInOut": ExponentialEaseInOut, + "ElasticIn": ElasticEaseIn, + "ElasticOut": ElasticEaseOut, + "ElasticInOut": ElasticEaseInOut, + "BackIn": BackEaseIn, + "BackOut": BackEaseOut, + "BackInOut": BackEaseInOut, + "BounceIn": BounceEaseIn, + "BounceOut": BounceEaseOut, + "BounceInOut": BounceEaseInOut, +} + +EASING_FUNCTION_KEYS: Any = Literal[ + tuple(list(EASING_FUNCTIONS_MAP.keys())) +] + + +# actually I think for now could just use CollectionOutput (which is list[Any] +class StepParamEasingInvocation(BaseInvocation): + """Experimental per-step parameter easing for denoising steps""" + + type: Literal["step_param_easing"] = "step_param_easing" + + # Inputs + # fmt: off + easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use") + num_steps: int = Field(default=20, description="number of denoising steps") + start_value: float = Field(default=0.0, description="easing starting value") + end_value: float = Field(default=1.0, description="easing ending value") + start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing") + end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing") + # if None, then start_value is used prior to easing start + pre_start_value: Optional[float] = Field(default=None, description="value before easing start") + # if None, then end value is used prior to easing end + post_end_value: Optional[float] = Field(default=None, description="value after easing end") + mirror: bool = Field(default=False, description="include mirror of easing function") + # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely + # alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing") + show_easing_plot: bool = Field(default=False, description="show easing plot") + # fmt: on + + + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + log_diagnostics = False + # convert from start_step_percent to nearest step <= (steps * start_step_percent) + # start_step = int(np.floor(self.num_steps * self.start_step_percent)) + start_step = int(np.round(self.num_steps * self.start_step_percent)) + # convert from end_step_percent to nearest step >= (steps * end_step_percent) + # end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent)) + end_step = int(np.round((self.num_steps - 1) * self.end_step_percent)) + + # end_step = int(np.ceil(self.num_steps * self.end_step_percent)) + num_easing_steps = end_step - start_step + 1 + + # num_presteps = max(start_step - 1, 0) + num_presteps = start_step + num_poststeps = self.num_steps - (num_presteps + num_easing_steps) + prelist = list(num_presteps * [self.pre_start_value]) + postlist = list(num_poststeps * [self.post_end_value]) + + if log_diagnostics: + logger = InvokeAILogger.getLogger(name="StepParamEasing") + logger.debug("start_step: " + str(start_step)) + logger.debug("end_step: " + str(end_step)) + logger.debug("num_easing_steps: " + str(num_easing_steps)) + logger.debug("num_presteps: " + str(num_presteps)) + logger.debug("num_poststeps: " + str(num_poststeps)) + logger.debug("prelist size: " + str(len(prelist))) + logger.debug("postlist size: " + str(len(postlist))) + logger.debug("prelist: " + str(prelist)) + logger.debug("postlist: " + str(postlist)) + + easing_class = EASING_FUNCTIONS_MAP[self.easing] + if log_diagnostics: + logger.debug("easing 
class: " + str(easing_class)) + easing_list = list() + if self.mirror: # "expected" mirroring + # if number of steps is even, squeeze duration down to (number_of_steps)/2 + # and create reverse copy of list to append + # if number of steps is odd, squeeze duration down to ceil(number_of_steps/2) + # and create reverse copy of list[1:end-1] + # but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always + + base_easing_duration = int(np.ceil(num_easing_steps/2.0)) + if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration)) + even_num_steps = (num_easing_steps % 2 == 0) # even number of steps + easing_function = easing_class(start=self.start_value, + end=self.end_value, + duration=base_easing_duration - 1) + base_easing_vals = list() + for step_index in range(base_easing_duration): + easing_val = easing_function.ease(step_index) + base_easing_vals.append(easing_val) + if log_diagnostics: + logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) + if even_num_steps: + mirror_easing_vals = list(reversed(base_easing_vals)) + else: + mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) + if log_diagnostics: + logger.debug("base easing vals: " + str(base_easing_vals)) + logger.debug("mirror easing vals: " + str(mirror_easing_vals)) + easing_list = base_easing_vals + mirror_easing_vals + + # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely + # elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me)) + # # half_ease_duration = round(num_easing_steps - 1 / 2) + # half_ease_duration = round((num_easing_steps - 1) / 2) + # easing_function = easing_class(start=self.start_value, + # end=self.end_value, + # duration=half_ease_duration, + # ) + # + # mirror_function = easing_class(start=self.end_value, + # end=self.start_value, + # duration=half_ease_duration, + # ) + # for step_index in range(num_easing_steps): + # if step_index <= half_ease_duration: + # step_val = easing_function.ease(step_index) + # else: + # step_val = mirror_function.ease(step_index - half_ease_duration) + # easing_list.append(step_val) + # if log_diagnostics: logger.debug(step_index, step_val) + # + + else: # no mirroring (default) + easing_function = easing_class(start=self.start_value, + end=self.end_value, + duration=num_easing_steps - 1) + for step_index in range(num_easing_steps): + step_val = easing_function.ease(step_index) + easing_list.append(step_val) + if log_diagnostics: + logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) + + if log_diagnostics: + logger.debug("prelist size: " + str(len(prelist))) + logger.debug("easing_list size: " + str(len(easing_list))) + logger.debug("postlist size: " + str(len(postlist))) + + param_list = prelist + easing_list + postlist + + if self.show_easing_plot: + plt.figure() + plt.xlabel("Step") + plt.ylabel("Param Value") + plt.title("Per-Step Values Based On Easing: " + self.easing) + plt.bar(range(len(param_list)), param_list) + # plt.plot(param_list) + ax = plt.gca() + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + buf = io.BytesIO() + plt.savefig(buf, format='png') + buf.seek(0) + im = PIL.Image.open(buf) + im.show() + buf.close() + + # output array of size steps, each entry list[i] is param value for step i + return FloatCollectionOutput( + collection=param_list + ) diff --git a/invokeai/app/models/metadata.py b/invokeai/app/models/metadata.py index ac87405423..8d90ca0bc8 100644 --- 
a/invokeai/app/models/metadata.py +++ b/invokeai/app/models/metadata.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union, List from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr @@ -47,7 +47,9 @@ class ImageMetadata(BaseModel): default=None, description="The seed used for noise generation." ) """The seed used for noise generation""" - cfg_scale: Optional[StrictFloat] = Field( + # cfg_scale: Optional[StrictFloat] = Field( + # cfg_scale: Union[float, list[float]] = Field( + cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field( default=None, description="The classifier-free guidance scale." ) """The classifier-free guidance scale""" diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 94be8225da..e3cd3d47ce 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -65,7 +65,6 @@ from typing import Optional, Union, List, get_args def is_union_subtype(t1, t2): t1_args = get_args(t1) t2_args = get_args(t2) - if not t1_args: # t1 is a single type return t1 in t2_args @@ -86,7 +85,6 @@ def is_list_or_contains_list(t): for arg in t_args: if get_origin(arg) is list: return True - return False @@ -393,7 +391,7 @@ class Graph(BaseModel): from_node = self.get_node(edge.source.node_id) to_node = self.get_node(edge.destination.node_id) except NodeNotFoundError: - raise InvalidEdgeError("One or both nodes don't exist") + raise InvalidEdgeError(f"One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}") # Validate that an edge to this node+field doesn't already exist input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) @@ -404,41 +402,41 @@ class Graph(BaseModel): g = self.nx_graph_flat() g.add_edge(edge.source.node_id, edge.destination.node_id) if not nx.is_directed_acyclic_graph(g): - raise InvalidEdgeError(f'Edge creates a cycle in the graph') + raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}') # Validate that the field types are compatible if not are_connections_compatible( from_node, edge.source.field, to_node, edge.destination.field ): - raise InvalidEdgeError(f'Fields are incompatible') + raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') # Validate if iterator output type matches iterator input type (if this edge results in both being set) if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if not self._is_iterator_connection_valid( edge.destination.node_id, new_input=edge.source ): - raise InvalidEdgeError(f'Iterator input type does not match iterator output type') + raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') # Validate if iterator input type matches output type (if this edge results in both being set) if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if not self._is_iterator_connection_valid( edge.source.node_id, new_output=edge.destination ): - raise InvalidEdgeError(f'Iterator output type does not match iterator input type') + raise InvalidEdgeError(f'Iterator output type does not match iterator input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') # Validate if collector input type matches 
output type (if this edge results in both being set) if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if not self._is_collector_connection_valid( edge.destination.node_id, new_input=edge.source ): - raise InvalidEdgeError(f'Collector output type does not match collector input type') + raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') # Validate if collector output type matches input type (if this edge results in both being set) if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if not self._is_collector_connection_valid( edge.source.node_id, new_output=edge.destination ): - raise InvalidEdgeError(f'Collector input type does not match collector output type') + raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') def has_node(self, node_path: str) -> bool: diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 2867630881..6a11891979 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): class ControlNetData: model: ControlNetModel = Field(default=None) image_tensor: torch.Tensor= Field(default=None) - weight: float = Field(default=1.0) + weight: Union[float, List[float]] = Field(default=1.0) begin_step_percent: float = Field(default=0.0) end_step_percent: float = Field(default=1.0) @@ -226,7 +226,7 @@ class ControlNetData: class ConditioningData: unconditioned_embeddings: torch.Tensor text_embeddings: torch.Tensor - guidance_scale: float + guidance_scale: Union[float, List[float]] """ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). @@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): down_block_res_samples, mid_block_res_sample = None, None if control_data is not None: - if conditioning_data.guidance_scale > 1.0: + # FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting + # if conditioning_data.guidance_scale > 1.0: + if conditioning_data.guidance_scale is not None: # expand the latents input to control model if doing classifier free guidance # (which I think for now is always true, there is conditional elsewhere that stops execution if # classifier_free_guidance is <= 1.0 ?) 
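
Review note: the per-step hunks below all reduce to the same operation — resolving a Union[float, List[float]] against the current denoising step index. A minimal, self-contained sketch of that resolution logic (the helper name resolve_for_step is illustrative, not part of this diff):

from typing import List, Union

def resolve_for_step(value: Union[float, List[float]], step_index: int) -> float:
    # A scalar applies to every step; a list supplies one value per step,
    # so it must hold at least step_index + 1 entries (the easing/range
    # nodes emit exactly num_steps values).
    if isinstance(value, list):
        return value[step_index]
    return value

# e.g. fade a ControlNet weight out over a 20-step run:
weights = [1.0 - i / 19 for i in range(20)]
assert resolve_for_step(weights, 0) == 1.0
assert resolve_for_step(weights, 19) == 0.0
assert resolve_for_step(0.75, 7) == 0.75
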
@@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # only apply controlnet if current step is within the controlnet's begin/end step range if step_index >= first_control_step and step_index <= last_control_step: # print("running controlnet", i, "for step", step_index) + if isinstance(control_datum.weight, list): + # if controlnet has multiple weights, use the weight for the current step + controlnet_weight = control_datum.weight[step_index] + else: + # if controlnet has a single weight, use it for all steps + controlnet_weight = control_datum.weight down_samples, mid_sample = control_datum.model( sample=latent_control_input, timestep=timestep, encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings]), controlnet_cond=control_datum.image_tensor, - conditioning_scale=control_datum.weight, + conditioning_scale=controlnet_weight, # cross_attention_kwargs, guess_mode=False, return_dict=False, diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 79043b13f5..eec8097857 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from dataclasses import dataclass from math import ceil -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union, List import numpy as np import torch @@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent: sigma: torch.Tensor, unconditioning: Union[torch.Tensor, dict], conditioning: Union[torch.Tensor, dict], - unconditional_guidance_scale: float, + # unconditional_guidance_scale: float, + unconditional_guidance_scale: Union[float, List[float]], step_index: Optional[int] = None, total_step_count: Optional[int] = None, **kwargs, @@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent: :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. 
""" + if isinstance(unconditional_guidance_scale, list): + guidance_scale = unconditional_guidance_scale[step_index] + else: + guidance_scale = unconditional_guidance_scale + cross_attention_control_types_to_do = [] context: Context = self.cross_attention_control_context if self.cross_attention_control_context is not None: @@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent: ) combined_next_x = self._combine( - unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale + # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale + unconditioned_next_x, conditioned_next_x, guidance_scale ) return combined_next_x @@ -497,7 +504,7 @@ class InvokeAIDiffuserComponent: logger.debug( f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}" ) - logger.debug( + logger.debug( f"{outside / latents.numel() * 100:.2f}% values outside threshold" ) diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index a9ae209178..e69625778b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -18,6 +18,8 @@ export const FIELD_TYPE_MAP: Record = { ColorField: 'color', ControlField: 'control', control: 'control', + cfg_scale: 'float', + control_weight: 'float', }; const COLOR_TOKEN_VALUE = 500; diff --git a/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts new file mode 100644 index 0000000000..6dea43dc32 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies HED edge detection to image + */ +export type HedImageprocessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + type?: 'hed_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * whether to use scribble mode + */ + scribble?: boolean; +}; + diff --git a/pyproject.toml b/pyproject.toml index 68c9a64e92..ddf4667eef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "datasets", "diffusers[torch]~=0.16.1", "dnspython==2.2.1", + "easing-functions", "einops", "eventlet", "facexlib", @@ -56,6 +57,7 @@ dependencies = [ "flaskwebgui==1.0.3", "gfpgan==1.3.8", "huggingface-hub>=0.11.1", + "matplotlib", # needed for plotting of Penner easing functions "mediapipe", # needed for "mediapipeface" controlnet model "npyscreen", "numpy<1.24",