Merge branch 'main' into diffusers-upgrade

Author: blessedcoolant
Committed by: GitHub
Date: 2023-06-13 05:29:15 +12:00
55 changed files with 1277 additions and 361 deletions

View File

@@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
- weight: float = Field(default=1.0)
+ weight: Union[float, List[float]]= Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
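With weight widened from float to Union[float, List[float]], a caller can now supply one ControlNet conditioning weight per denoising step instead of a single constant. A minimal sketch of building such a schedule (the helper name and the linear ramp are illustrative assumptions, not part of this commit):

    from typing import List

    def ramp_weights(start: float, end: float, num_steps: int) -> List[float]:
        # One ControlNet weight per denoising step, linearly
        # interpolated from `start` to `end`.
        if num_steps <= 1:
            return [start]
        return [start + (end - start) * i / (num_steps - 1) for i in range(num_steps)]

    # e.g. fade the ControlNet out over a 30-step run (hypothetical usage):
    # control_data.weight = ramp_weights(1.0, 0.0, 30)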
@@ -226,7 +226,7 @@ class ControlNetData:
class ConditioningData:
unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor
- guidance_scale: float
+ guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
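guidance_scale gets the same treatment: instead of one CFG scale for the whole run, it may now be a list with one value per denoising step. An illustrative schedule (the values and split point are arbitrary, not from this commit):

    # Strong conditioning for the first half of a 10-step run, weaker after;
    # the list length must match the number of inference steps.
    guidance_schedule = [7.5] * 5 + [4.0] * 5
    # conditioning_data.guidance_scale = guidance_schedule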
@@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
- if conditioning_data.guidance_scale > 1.0:
+ # FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
+ # if conditioning_data.guidance_scale > 1.0:
+ if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
@@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
+ if isinstance(control_datum.weight, list):
+     # if controlnet has multiple weights, use the weight for the current step
+     controlnet_weight = control_datum.weight[step_index]
+ else:
+     # if controlnet has a single weight, use it for all steps
+     controlnet_weight = control_datum.weight
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
- conditioning_scale=control_datum.weight,
+ conditioning_scale=controlnet_weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
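Taken together with begin_step_percent / end_step_percent, the per-step lookup above amounts to the following resolution rule. This is a sketch for illustration only: the hunk does not show how the percentages are turned into first_control_step / last_control_step, so the floor/ceil rounding below is an assumption.

    import math
    from typing import List, Union

    def controlnet_weight_for_step(
        weight: Union[float, List[float]],
        step_index: int,
        total_steps: int,
        begin_step_percent: float = 0.0,
        end_step_percent: float = 1.0,
    ) -> float:
        # The ControlNet only contributes within its begin/end step range
        # (rounding here is assumed, not taken from the pipeline).
        first_control_step = math.floor(begin_step_percent * (total_steps - 1))
        last_control_step = math.ceil(end_step_percent * (total_steps - 1))
        if not (first_control_step <= step_index <= last_control_step):
            return 0.0
        # Per-step weight if a schedule was supplied, otherwise the constant weight.
        return weight[step_index] if isinstance(weight, list) else weight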

View File

@@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
- from typing import Any, Callable, Dict, Optional, Union
+ from typing import Any, Callable, Dict, Optional, Union, List
import numpy as np
import torch
@@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent:
sigma: torch.Tensor,
unconditioning: Union[torch.Tensor, dict],
conditioning: Union[torch.Tensor, dict],
- unconditional_guidance_scale: float,
+ # unconditional_guidance_scale: float,
+ unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
**kwargs,
@@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
+ if isinstance(unconditional_guidance_scale, list):
+     guidance_scale = unconditional_guidance_scale[step_index]
+ else:
+     guidance_scale = unconditional_guidance_scale
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
@@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent:
)
combined_next_x = self._combine(
- unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
+ # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
+ unconditioned_next_x, conditioned_next_x, guidance_scale
)
return combined_next_x
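On the InvokeAIDiffuserComponent side, the list variant is resolved to a single scale for the current step before the unconditioned and conditioned predictions are merged. The blend below is the standard classifier-free guidance formula; the diff does not show _combine itself, so treating it as this formula is an assumption.

    from typing import List, Optional, Union

    import torch

    def cfg_combine(
        unconditioned_next_x: torch.Tensor,
        conditioned_next_x: torch.Tensor,
        unconditional_guidance_scale: Union[float, List[float]],
        step_index: Optional[int] = None,
    ) -> torch.Tensor:
        # Pick the per-step scale when a schedule was supplied, as in the diff above.
        if isinstance(unconditional_guidance_scale, list):
            scale = unconditional_guidance_scale[step_index]
        else:
            scale = unconditional_guidance_scale
        # Standard CFG blend: move from the unconditioned prediction toward
        # the conditioned one by `scale`.
        return unconditioned_next_x + scale * (conditioned_next_x - unconditioned_next_x)

    # hypothetical usage with a per-step schedule:
    # noise_pred = cfg_combine(uncond_out, cond_out, [7.5, 7.5, 5.0, 5.0], step_index=2)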
@@ -497,7 +504,7 @@ class InvokeAIDiffuserComponent:
logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
- logger.debug(
+ logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
)