mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
This commit is contained in:
@ -47,6 +47,7 @@ from .diffusion import (
|
||||
)
|
||||
from .offloading import FullyLoadedModelGroup, ModelGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
run_id: str
|
||||
@ -72,7 +73,11 @@ class AddsMaskLatents:
|
||||
initial_image_latents: torch.Tensor
|
||||
|
||||
def __call__(
|
||||
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
text_embeddings: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
model_input = self.add_mask_channels(latents)
|
||||
return self.forward(model_input, t, text_embeddings, **kwargs)
|
||||
@ -80,12 +85,8 @@ class AddsMaskLatents:
|
||||
def add_mask_channels(self, latents):
|
||||
batch_size = latents.size(0)
|
||||
# duplicate mask and latents for each batch
|
||||
mask = einops.repeat(
|
||||
self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
|
||||
)
|
||||
image_latents = einops.repeat(
|
||||
self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
|
||||
)
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
# add mask and image as additional channels
|
||||
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
|
||||
return model_input
|
||||
@ -103,9 +104,7 @@ class AddsMaskGuidance:
|
||||
noise: torch.Tensor
|
||||
_debug: Optional[Callable] = None
|
||||
|
||||
def __call__(
|
||||
self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
|
||||
) -> BaseOutput:
|
||||
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
|
||||
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||
|
||||
# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
|
||||
@ -116,11 +115,7 @@ class AddsMaskGuidance:
|
||||
# Mask anything that has the same shape as prev_sample, return others as-is.
|
||||
return output_class(
|
||||
{
|
||||
k: (
|
||||
self.apply_mask(v, self._t_for_field(k, t))
|
||||
if are_like_tensors(prev_sample, v)
|
||||
else v
|
||||
)
|
||||
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
|
||||
for k, v in step_output.items()
|
||||
}
|
||||
)
|
||||
@ -132,9 +127,7 @@ class AddsMaskGuidance:
|
||||
|
||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(
|
||||
self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
|
||||
)
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if t.dim() == 0:
|
||||
# some schedulers expect t to be one-dimensional.
|
||||
# TODO: file diffusers bug about inconsistency?
|
||||
@ -144,12 +137,8 @@ class AddsMaskGuidance:
|
||||
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
|
||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(
|
||||
mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
|
||||
)
|
||||
masked_input = torch.lerp(
|
||||
mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)
|
||||
)
|
||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||
if self._debug:
|
||||
self._debug(masked_input, f"t={t} lerped")
|
||||
return masked_input
|
||||
@ -159,9 +148,7 @@ def trim_to_multiple_of(*args, multiple_of=8):
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
|
||||
def image_resized_to_grid_as_tensor(
|
||||
image: PIL.Image.Image, normalize: bool = True, multiple_of=8
|
||||
) -> torch.FloatTensor:
|
||||
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
|
||||
"""
|
||||
|
||||
:param image: input image
|
||||
@ -211,6 +198,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlNetData:
|
||||
model: ControlNetModel = Field(default=None)
|
||||
@ -341,9 +329,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# FIXME: can't currently register control module
|
||||
# control_model=control_model,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||
self.unet, self._unet_forward
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
|
||||
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
||||
self._model_group.install(*self._submodels)
|
||||
@ -354,11 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and is_xformers_available()
|
||||
and not config.disable_xformers
|
||||
):
|
||||
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
if self.device.type == "cpu" or self.device.type == "mps":
|
||||
@ -369,9 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
raise ValueError(f"unrecognized device {self.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
||||
bytes_per_element_needed_for_baddbmm_duplication = (
|
||||
latents.element_size() + 4
|
||||
)
|
||||
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
||||
max_size_required_for_baddbmm = (
|
||||
16
|
||||
* latents.size(dim=2)
|
||||
@ -380,9 +360,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
* latents.size(dim=3)
|
||||
* bytes_per_element_needed_for_baddbmm_duplication
|
||||
)
|
||||
if max_size_required_for_baddbmm > (
|
||||
mem_free * 3.0 / 4.0
|
||||
): # 3.3 / 4.0 is from old Invoke code
|
||||
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
elif torch.backends.mps.is_available():
|
||||
# diffusers recommends always enabling for mps
|
||||
@ -470,7 +448,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
control_data: List[ControlNetData] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
scheduler_device = torch.device("cpu")
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
@ -488,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
run_id=run_id,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
|
||||
callback=callback,
|
||||
)
|
||||
return result.latents, result.attention_map_saver
|
||||
@ -511,9 +488,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
yield PipelineIntermediateState(
|
||||
run_id=run_id,
|
||||
@ -607,16 +584,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# that are combined at higher level to make control_mode enum
|
||||
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
|
||||
# or default weighting (if False)
|
||||
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
|
||||
soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
|
||||
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
|
||||
# or the default both conditional and unconditional (if False)
|
||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
||||
|
||||
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
|
||||
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||
if step_index >= first_control_step and step_index <= last_control_step:
|
||||
|
||||
if cfg_injection:
|
||||
control_latent_input = unet_latent_input
|
||||
else:
|
||||
@ -629,7 +605,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
encoder_hidden_states = conditioning_data.text_embeddings
|
||||
encoder_attention_mask = None
|
||||
else:
|
||||
encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
|
||||
(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self.invokeai_diffuser._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings,
|
||||
)
|
||||
@ -646,9 +625,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||
return_dict=False,
|
||||
)
|
||||
if cfg_injection:
|
||||
@ -678,13 +657,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
|
||||
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
|
||||
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(
|
||||
noise_pred, timestep, latents, **conditioning_data.scheduler_args
|
||||
)
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
||||
|
||||
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
|
||||
# But the way things are now, scheduler runs _after_ that, so there was
|
||||
@ -710,17 +687,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# use of AddsMaskLatents.
|
||||
latents = AddsMaskLatents(
|
||||
self._unet_forward,
|
||||
mask=torch.ones_like(
|
||||
latents[:1, :1], device=latents.device, dtype=latents.dtype
|
||||
),
|
||||
initial_image_latents=torch.zeros_like(
|
||||
latents[:1], device=latents.device, dtype=latents.dtype
|
||||
),
|
||||
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
|
||||
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
|
||||
).add_mask_channels(latents)
|
||||
|
||||
# First three args should be positional, not keywords, so torch hooks can see them.
|
||||
return self.unet(
|
||||
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
|
||||
latents,
|
||||
t,
|
||||
text_embeddings,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
@ -774,9 +750,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
||||
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
||||
),
|
||||
latents=initial_latents
|
||||
if strength < 1.0
|
||||
else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
timesteps=timesteps,
|
||||
@ -797,14 +773,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def get_img2img_timesteps(
|
||||
self, num_inference_steps: int, strength: float, device=None
|
||||
) -> (torch.Tensor, int):
|
||||
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
assert img2img_pipeline.scheduler is self.scheduler
|
||||
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
scheduler_device = torch.device("cpu")
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
@ -849,18 +823,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# 6. Prepare latent variables
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
# because we have our own noise function
|
||||
init_image_latents = self.non_noised_latents_from_image(
|
||||
init_image, device=device, dtype=latents_dtype
|
||||
)
|
||||
init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
noise = noise_func(init_image_latents)
|
||||
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
latent_mask = tv_resize(
|
||||
mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR
|
||||
).to(device=device, dtype=latents_dtype)
|
||||
latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR).to(
|
||||
device=device, dtype=latents_dtype
|
||||
)
|
||||
|
||||
guidance: List[Callable] = []
|
||||
|
||||
@ -868,22 +840,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
|
||||
# (that's why there's a mask!) but it seems to really want that blanked out.
|
||||
masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
|
||||
masked_latents = self.non_noised_latents_from_image(
|
||||
masked_init_image, device=device, dtype=latents_dtype
|
||||
)
|
||||
masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)
|
||||
|
||||
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
|
||||
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
|
||||
self._unet_forward, latent_mask, masked_latents
|
||||
)
|
||||
else:
|
||||
guidance.append(
|
||||
AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)
|
||||
)
|
||||
guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))
|
||||
|
||||
try:
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents=init_image_latents if strength < 1.0 else torch.zeros_like(
|
||||
latents=init_image_latents
|
||||
if strength < 1.0
|
||||
else torch.zeros_like(
|
||||
init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
|
||||
),
|
||||
num_inference_steps=num_inference_steps,
|
||||
@ -914,18 +884,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
with torch.inference_mode():
|
||||
self._model_group.load(self.vae)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample().to(
|
||||
dtype=dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
return init_latents
|
||||
|
||||
def check_for_safety(self, output, dtype):
|
||||
with torch.inference_mode():
|
||||
screened_images, has_nsfw_concept = self.run_safety_checker(
|
||||
output.images, dtype=dtype
|
||||
)
|
||||
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
|
||||
screened_attention_map_saver = None
|
||||
if has_nsfw_concept is None or not has_nsfw_concept:
|
||||
screened_attention_map_saver = output.attention_map_saver
|
||||
@ -949,9 +915,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
def debug_latents(self, latents, msg):
|
||||
from invokeai.backend.image_util import debug_image
|
||||
|
||||
with torch.inference_mode():
|
||||
decoded = self.numpy_to_pil(self.decode_latents(latents))
|
||||
for i, img in enumerate(decoded):
|
||||
debug_image(
|
||||
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
||||
)
|
||||
debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
|
||||
|
@ -17,6 +17,7 @@ from torch import nn
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...util import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
@ -55,9 +56,7 @@ class Context:
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name, module in get_cross_attention_modules(
|
||||
model, CrossAttentionType.TOKENS
|
||||
):
|
||||
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
@ -68,9 +67,7 @@ class Context:
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||
|
||||
def request_apply_saved_attention_maps(
|
||||
self, cross_attention_type: CrossAttentionType
|
||||
):
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.APPLY
|
||||
else:
|
||||
@ -139,9 +136,7 @@ class Context:
|
||||
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||
if requested_dim is None:
|
||||
if saved_attention_dict["dim"] is not None:
|
||||
raise RuntimeError(
|
||||
f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}"
|
||||
)
|
||||
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||
return saved_attention_dict["slices"][0]
|
||||
|
||||
if saved_attention_dict["dim"] == requested_dim:
|
||||
@ -154,21 +149,13 @@ class Context:
|
||||
if saved_attention_dict["dim"] is None:
|
||||
whole_saved_attention = saved_attention_dict["slices"][0]
|
||||
if requested_dim == 0:
|
||||
return whole_saved_attention[
|
||||
requested_offset : requested_offset + slice_size
|
||||
]
|
||||
return whole_saved_attention[requested_offset : requested_offset + slice_size]
|
||||
elif requested_dim == 1:
|
||||
return whole_saved_attention[
|
||||
:, requested_offset : requested_offset + slice_size
|
||||
]
|
||||
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
|
||||
|
||||
raise RuntimeError(
|
||||
f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}"
|
||||
)
|
||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||
|
||||
def get_slicing_strategy(
|
||||
self, identifier: str
|
||||
) -> tuple[Optional[int], Optional[int]]:
|
||||
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||
if saved_attention is None:
|
||||
return None, None
|
||||
@ -201,9 +188,7 @@ class InvokeAICrossAttentionMixin:
|
||||
|
||||
def set_attention_slice_wrangler(
|
||||
self,
|
||||
wrangler: Optional[
|
||||
Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]
|
||||
],
|
||||
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
|
||||
):
|
||||
"""
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
@ -219,14 +204,10 @@ class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(
|
||||
self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]
|
||||
):
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def set_attention_slice_calculated_callback(
|
||||
self, callback: Optional[Callable[[torch.Tensor], None]]
|
||||
):
|
||||
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||
self.attention_slice_calculated_callback = callback
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
@ -247,45 +228,31 @@ class InvokeAICrossAttentionMixin:
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(
|
||||
dim=-1, dtype=attention_scores.dtype
|
||||
)
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(
|
||||
self, default_attention_slice, dim, offset, slice_size
|
||||
)
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
if self.attention_slice_calculated_callback is not None:
|
||||
self.attention_slice_calculated_callback(
|
||||
attention_slice, dim, offset, slice_size
|
||||
)
|
||||
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(
|
||||
q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
|
||||
)
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(
|
||||
q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size
|
||||
)
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(
|
||||
q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
|
||||
)
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(
|
||||
q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size
|
||||
)
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
@ -353,6 +320,7 @@ def restore_default_cross_attention(
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
@ -372,7 +340,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
@ -386,16 +354,14 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||
slice_size = next(
|
||||
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
|
||||
)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
def get_cross_attention_modules(
|
||||
model, which: CrossAttentionType
|
||||
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
|
||||
cross_attention_class: type = (
|
||||
InvokeAIDiffusersCrossAttention
|
||||
)
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [
|
||||
(name, module)
|
||||
@ -420,9 +386,7 @@ def get_cross_attention_modules(
|
||||
def inject_attention_function(unet, context: Context):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(
|
||||
module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size
|
||||
):
|
||||
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
|
||||
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||
|
||||
attention_slice = suggested_attention_slice
|
||||
@ -430,9 +394,7 @@ def inject_attention_function(unet, context: Context):
|
||||
if context.get_should_save_maps(module.identifier):
|
||||
# print(module.identifier, "saving suggested_attention_slice of shape",
|
||||
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||
slice_to_save = (
|
||||
attention_slice.to("cpu") if dim is not None else attention_slice
|
||||
)
|
||||
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
|
||||
context.save_slice(
|
||||
module.identifier,
|
||||
slice_to_save,
|
||||
@ -442,31 +404,20 @@ def inject_attention_function(unet, context: Context):
|
||||
)
|
||||
elif context.get_should_apply_saved_maps(module.identifier):
|
||||
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||
saved_attention_slice = context.get_slice(
|
||||
module.identifier, dim, offset, slice_size
|
||||
)
|
||||
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||
|
||||
# slice may have been offloaded to CPU
|
||||
saved_attention_slice = saved_attention_slice.to(
|
||||
suggested_attention_slice.device
|
||||
)
|
||||
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||
|
||||
if context.is_tokens_cross_attention(module.identifier):
|
||||
index_map = context.cross_attention_index_map
|
||||
remapped_saved_attention_slice = torch.index_select(
|
||||
saved_attention_slice, -1, index_map
|
||||
)
|
||||
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||
this_attention_slice = suggested_attention_slice
|
||||
|
||||
mask = context.cross_attention_mask.to(
|
||||
torch_dtype(suggested_attention_slice.device)
|
||||
)
|
||||
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
|
||||
saved_mask = mask
|
||||
this_mask = 1 - mask
|
||||
attention_slice = (
|
||||
remapped_saved_attention_slice * saved_mask
|
||||
+ this_attention_slice * this_mask
|
||||
)
|
||||
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
|
||||
else:
|
||||
# just use everything
|
||||
attention_slice = saved_attention_slice
|
||||
@ -480,14 +431,10 @@ def inject_attention_function(unet, context: Context):
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(
|
||||
lambda module: context.get_slicing_strategy(identifier)
|
||||
)
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(
|
||||
f"TODO: implement set_attention_slice_wrangler for {type(module)}"
|
||||
) # TODO
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
else:
|
||||
raise
|
||||
|
||||
@ -503,9 +450,7 @@ def remove_attention_function(unet):
|
||||
module.set_slicing_strategy_getter(None)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(
|
||||
f"TODO: implement set_attention_slice_wrangler for {type(module)}"
|
||||
)
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
|
||||
else:
|
||||
raise
|
||||
|
||||
@ -530,9 +475,7 @@ def get_mem_free_total(device):
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(
|
||||
diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
|
||||
):
|
||||
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
@ -641,11 +584,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
# kwargs
|
||||
swap_cross_attn_context: SwapCrossAttnContext = None,
|
||||
):
|
||||
attention_type = (
|
||||
CrossAttentionType.SELF
|
||||
if encoder_hidden_states is None
|
||||
else CrossAttentionType.TOKENS
|
||||
)
|
||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||
|
||||
# if cross-attention control is not in play, just call through to the base implementation.
|
||||
if (
|
||||
@ -654,9 +593,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
|
||||
):
|
||||
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
|
||||
return super().__call__(
|
||||
attn, hidden_states, encoder_hidden_states, attention_mask
|
||||
)
|
||||
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
# else:
|
||||
# print(f"SwapCrossAttnContext for {attention_type} active")
|
||||
|
||||
@ -699,18 +636,10 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
query_slice = query[start_idx:end_idx]
|
||||
original_key_slice = original_text_key[start_idx:end_idx]
|
||||
modified_key_slice = modified_text_key[start_idx:end_idx]
|
||||
attn_mask_slice = (
|
||||
attention_mask[start_idx:end_idx]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
)
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
original_attn_slice = attn.get_attention_scores(
|
||||
query_slice, original_key_slice, attn_mask_slice
|
||||
)
|
||||
modified_attn_slice = attn.get_attention_scores(
|
||||
query_slice, modified_key_slice, attn_mask_slice
|
||||
)
|
||||
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
|
||||
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
|
||||
|
||||
# because the prompt modifications may result in token sequences shifted forwards or backwards,
|
||||
# the original attention probabilities must be remapped to account for token index changes in the
|
||||
@ -722,9 +651,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
|
||||
mask = swap_cross_attn_context.mask
|
||||
inverse_mask = 1 - mask
|
||||
attn_slice = (
|
||||
remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
|
||||
)
|
||||
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
|
||||
|
||||
del remapped_original_attn_slice, modified_attn_slice
|
||||
|
||||
@ -744,6 +671,4 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
|
||||
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
|
||||
def __init__(self):
|
||||
super(SwapCrossAttnProcessor, self).__init__(
|
||||
slice_size=int(1e9)
|
||||
) # massive slice size = don't slice
|
||||
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice
|
||||
|
@ -59,9 +59,7 @@ class AttentionMapSaver:
|
||||
for key, maps in self.collated_maps.items():
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(
|
||||
maps.shape[0] / (latents_width * latents_height)
|
||||
)
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||
# and we need to do some dimension juggling
|
||||
@ -72,9 +70,7 @@ class AttentionMapSaver:
|
||||
|
||||
# scale to output size if necessary
|
||||
if this_scale_factor != 1:
|
||||
maps = tv_resize(
|
||||
maps, [latents_height, latents_width], InterpolationMode.BICUBIC
|
||||
)
|
||||
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||
|
||||
# normalize
|
||||
maps_min = torch.min(maps)
|
||||
@ -83,9 +79,7 @@ class AttentionMapSaver:
|
||||
maps_normalized = (maps - maps_min) / maps_range
|
||||
# expand to (-0.1, 1.1) and clamp
|
||||
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||
maps_normalized_expanded_clamped = torch.clamp(
|
||||
maps_normalized_expanded, 0, 1
|
||||
)
|
||||
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||
|
||||
# merge together, producing a vertical stack
|
||||
maps_stacked = torch.reshape(
|
||||
|
@ -31,6 +31,7 @@ ModelForwardCallback: TypeAlias = Union[
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
@ -81,14 +82,12 @@ class InvokeAIDiffuserComponent:
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
cls,
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int
|
||||
step_count: int,
|
||||
):
|
||||
old_attn_processors = None
|
||||
if extra_conditioning_info and (
|
||||
extra_conditioning_info.wants_cross_attention_control
|
||||
):
|
||||
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
|
||||
old_attn_processors = unet.attn_processors
|
||||
# Load lora conditions into the model
|
||||
if extra_conditioning_info.wants_cross_attention_control:
|
||||
@ -116,27 +115,15 @@ class InvokeAIDiffuserComponent:
|
||||
return
|
||||
saver.add_attention_maps(slice, key)
|
||||
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(
|
||||
self.model, CrossAttentionType.TOKENS
|
||||
)
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for identifier, module in tokens_cross_attention_modules:
|
||||
key = (
|
||||
"down"
|
||||
if identifier.startswith("down")
|
||||
else "up"
|
||||
if identifier.startswith("up")
|
||||
else "mid"
|
||||
)
|
||||
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
|
||||
module.set_attention_slice_calculated_callback(
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(
|
||||
slice, dim, offset, slice_size, key
|
||||
)
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
|
||||
)
|
||||
|
||||
def remove_attention_map_saving(self):
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(
|
||||
self.model, CrossAttentionType.TOKENS
|
||||
)
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for _, module in tokens_cross_attention_modules:
|
||||
module.set_attention_slice_calculated_callback(None)
|
||||
|
||||
@ -171,10 +158,8 @@ class InvokeAIDiffuserComponent:
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
context.get_active_cross_attention_control_types_for_step(
|
||||
percent_through
|
||||
)
|
||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
|
||||
percent_through
|
||||
)
|
||||
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
@ -182,7 +167,11 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
if wants_hybrid_conditioning:
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
|
||||
x, sigma, unconditioning, conditioning, **kwargs,
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
**kwargs,
|
||||
)
|
||||
elif wants_cross_attention_control:
|
||||
(
|
||||
@ -201,7 +190,11 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning_sequentially(
|
||||
x, sigma, unconditioning, conditioning, **kwargs,
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
@ -209,12 +202,18 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
x, sigma, unconditioning, conditioning, **kwargs,
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
combined_next_x = self._combine(
|
||||
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||
unconditioned_next_x, conditioned_next_x, guidance_scale
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
guidance_scale,
|
||||
)
|
||||
|
||||
return combined_next_x
|
||||
@ -229,37 +228,47 @@ class InvokeAIDiffuserComponent:
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
latents = self.apply_threshold(
|
||||
postprocessing_settings, latents, percent_through
|
||||
)
|
||||
latents = self.apply_symmetry(
|
||||
postprocessing_settings, latents, percent_through
|
||||
)
|
||||
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
|
||||
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
|
||||
return latents
|
||||
|
||||
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
|
||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||
conditioning_attention_mask = torch.ones(
|
||||
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
|
||||
)
|
||||
|
||||
if cond.shape[1] < max_len:
|
||||
conditioning_attention_mask = torch.cat([
|
||||
conditioning_attention_mask,
|
||||
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||
], dim=1)
|
||||
conditioning_attention_mask = torch.cat(
|
||||
[
|
||||
conditioning_attention_mask,
|
||||
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
cond = torch.cat([
|
||||
cond,
|
||||
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
||||
], dim=1)
|
||||
cond = torch.cat(
|
||||
[
|
||||
cond,
|
||||
torch.zeros(
|
||||
(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
|
||||
device=cond.device,
|
||||
dtype=cond.dtype,
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = conditioning_attention_mask
|
||||
else:
|
||||
encoder_attention_mask = torch.cat([
|
||||
encoder_attention_mask,
|
||||
conditioning_attention_mask,
|
||||
])
|
||||
|
||||
encoder_attention_mask = torch.cat(
|
||||
[
|
||||
encoder_attention_mask,
|
||||
conditioning_attention_mask,
|
||||
]
|
||||
)
|
||||
|
||||
return cond, encoder_attention_mask
|
||||
|
||||
encoder_attention_mask = None
|
||||
@ -277,11 +286,11 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
unconditioning, conditioning
|
||||
)
|
||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice, sigma_twice, both_conditionings,
|
||||
x_twice,
|
||||
sigma_twice,
|
||||
both_conditionings,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
@ -312,13 +321,17 @@ class InvokeAIDiffuserComponent:
|
||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||
|
||||
unconditioned_next_x = self.model_forward_callback(
|
||||
x, sigma, unconditioning,
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
**kwargs,
|
||||
)
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x, sigma, conditioning,
|
||||
x,
|
||||
sigma,
|
||||
conditioning,
|
||||
down_block_additional_residuals=cond_down_block,
|
||||
mid_block_additional_residual=cond_mid_block,
|
||||
**kwargs,
|
||||
@ -335,13 +348,15 @@ class InvokeAIDiffuserComponent:
|
||||
for k in conditioning:
|
||||
if isinstance(conditioning[k], list):
|
||||
both_conditionings[k] = [
|
||||
torch.cat([unconditioning[k][i], conditioning[k][i]])
|
||||
for i in range(len(conditioning[k]))
|
||||
torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
|
||||
]
|
||||
else:
|
||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
|
||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||
x_twice,
|
||||
sigma_twice,
|
||||
both_conditionings,
|
||||
**kwargs,
|
||||
).chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
@ -388,9 +403,7 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
# do requested cross attention types for conditioning (positive prompt)
|
||||
cross_attn_processor_context.cross_attention_types_to_do = (
|
||||
cross_attention_control_types_to_do
|
||||
)
|
||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
@ -414,19 +427,14 @@ class InvokeAIDiffuserComponent:
|
||||
latents: torch.Tensor,
|
||||
percent_through: float,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
postprocessing_settings.threshold is None
|
||||
or postprocessing_settings.threshold == 0.0
|
||||
):
|
||||
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
|
||||
return latents
|
||||
|
||||
threshold = postprocessing_settings.threshold
|
||||
warmup = postprocessing_settings.warmup
|
||||
|
||||
if percent_through < warmup:
|
||||
current_threshold = threshold + threshold * 5 * (
|
||||
1 - (percent_through / warmup)
|
||||
)
|
||||
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
|
||||
else:
|
||||
current_threshold = threshold
|
||||
|
||||
@ -440,18 +448,10 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
if self.debug_thresholding:
|
||||
std, mean = [i.item() for i in torch.std_mean(latents)]
|
||||
outside = torch.count_nonzero(
|
||||
(latents < -current_threshold) | (latents > current_threshold)
|
||||
)
|
||||
logger.info(
|
||||
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
||||
)
|
||||
logger.debug(
|
||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||
)
|
||||
logger.debug(
|
||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
)
|
||||
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
|
||||
logger.info(f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})")
|
||||
logger.debug(f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}")
|
||||
logger.debug(f"{outside / latents.numel() * 100:.2f}% values outside threshold")
|
||||
|
||||
if maxval < current_threshold and minval > -current_threshold:
|
||||
return latents
|
||||
@ -464,25 +464,17 @@ class InvokeAIDiffuserComponent:
|
||||
latents = torch.clone(latents)
|
||||
maxval = np.clip(maxval * scale, 1, current_threshold)
|
||||
num_altered += torch.count_nonzero(latents > maxval)
|
||||
latents[latents > maxval] = (
|
||||
torch.rand_like(latents[latents > maxval]) * maxval
|
||||
)
|
||||
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
|
||||
|
||||
if minval < -current_threshold:
|
||||
latents = torch.clone(latents)
|
||||
minval = np.clip(minval * scale, -current_threshold, -1)
|
||||
num_altered += torch.count_nonzero(latents < minval)
|
||||
latents[latents < minval] = (
|
||||
torch.rand_like(latents[latents < minval]) * minval
|
||||
)
|
||||
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
|
||||
|
||||
if self.debug_thresholding:
|
||||
logger.debug(
|
||||
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
||||
)
|
||||
logger.debug(
|
||||
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
||||
)
|
||||
logger.debug(f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})")
|
||||
logger.debug(f"{num_altered / latents.numel() * 100:.2f}% values altered")
|
||||
|
||||
return latents
|
||||
|
||||
@ -501,15 +493,11 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
# Check for out of bounds
|
||||
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
|
||||
if h_symmetry_time_pct is not None and (
|
||||
h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
|
||||
):
|
||||
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
|
||||
h_symmetry_time_pct = None
|
||||
|
||||
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
|
||||
if v_symmetry_time_pct is not None and (
|
||||
v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
|
||||
):
|
||||
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
|
||||
v_symmetry_time_pct = None
|
||||
|
||||
dev = latents.device.type
|
||||
@ -554,9 +542,7 @@ class InvokeAIDiffuserComponent:
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
return float(step_index) / float(
|
||||
self.cross_attention_control_context.step_count
|
||||
)
|
||||
return float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||
# find the best possible index of the current sigma in the sigma sequence
|
||||
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
|
||||
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
|
||||
@ -567,19 +553,13 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(
|
||||
cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
|
||||
):
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = (
|
||||
c_or_weighted_c_list
|
||||
if type(c_or_weighted_c_list) is list
|
||||
else [(c_or_weighted_c_list, 1)]
|
||||
)
|
||||
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
|
||||
|
||||
# below is fugly omg
|
||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||
@ -608,15 +588,11 @@ class InvokeAIDiffuserComponent:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(
|
||||
weights[1:], dtype=deltas.dtype, device=deltas.device
|
||||
)
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(
|
||||
per_delta_weights.shape + (1, 1, 1)
|
||||
)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
|
@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3):
|
||||
year={2018}
|
||||
}
|
||||
"""
|
||||
x = ndimage.filters.convolve(
|
||||
x, np.expand_dims(k, axis=2), mode="wrap"
|
||||
) # 'nearest' | 'mirror'
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
@ -389,21 +387,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
|
||||
np.float32
|
||||
)
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(
|
||||
0, noise_level / 255.0, (*img.shape[:2], 1)
|
||||
).astype(np.float32)
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal(
|
||||
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
|
||||
).astype(np.float32)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
@ -413,21 +405,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
|
||||
np.float32
|
||||
)
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(
|
||||
0, noise_level / 255.0, (*img.shape[:2], 1)
|
||||
).astype(np.float32)
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal(
|
||||
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
|
||||
).astype(np.float32)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
@ -440,9 +426,7 @@ def add_Poisson_noise(img):
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
||||
noise_gray = (
|
||||
np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
)
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
@ -451,9 +435,7 @@ def add_Poisson_noise(img):
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(30, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode(
|
||||
".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
|
||||
)
|
||||
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
@ -540,9 +522,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.filters.convolve(
|
||||
img, np.expand_dims(k_shifted, axis=2), mode="mirror"
|
||||
)
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
@ -646,9 +626,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.filters.convolve(
|
||||
image, np.expand_dims(k_shifted, axis=2), mode="mirror"
|
||||
)
|
||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
@ -796,9 +774,7 @@ if __name__ == "__main__":
|
||||
print(i)
|
||||
img_lq = deg_fn(img)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(
|
||||
max_size=h, interpolation=cv2.INTER_CUBIC
|
||||
)(image=img)["image"]
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
print(img_hq.shape)
|
||||
@ -812,7 +788,5 @@ if __name__ == "__main__":
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0,
|
||||
)
|
||||
img_concat = np.concatenate(
|
||||
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
|
||||
)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
util.imsave(img_concat, str(i) + ".png")
|
||||
|
@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3):
|
||||
year={2018}
|
||||
}
|
||||
"""
|
||||
x = ndimage.filters.convolve(
|
||||
x, np.expand_dims(k, axis=2), mode="wrap"
|
||||
) # 'nearest' | 'mirror'
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
@ -393,21 +391,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
|
||||
np.float32
|
||||
)
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(
|
||||
0, noise_level / 255.0, (*img.shape[:2], 1)
|
||||
).astype(np.float32)
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal(
|
||||
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
|
||||
).astype(np.float32)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
@ -417,21 +409,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
|
||||
np.float32
|
||||
)
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(
|
||||
0, noise_level / 255.0, (*img.shape[:2], 1)
|
||||
).astype(np.float32)
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal(
|
||||
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
|
||||
).astype(np.float32)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
@ -444,9 +430,7 @@ def add_Poisson_noise(img):
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
||||
noise_gray = (
|
||||
np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
)
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
@ -455,9 +439,7 @@ def add_Poisson_noise(img):
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(80, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode(
|
||||
".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
|
||||
)
|
||||
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
@ -544,9 +526,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.filters.convolve(
|
||||
img, np.expand_dims(k_shifted, axis=2), mode="mirror"
|
||||
)
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
@ -653,9 +633,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.filters.convolve(
|
||||
image, np.expand_dims(k_shifted, axis=2), mode="mirror"
|
||||
)
|
||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
@ -705,9 +683,9 @@ if __name__ == "__main__":
|
||||
img_lq = deg_fn(img)["image"]
|
||||
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(
|
||||
max_size=h, interpolation=cv2.INTER_CUBIC
|
||||
)(image=img_hq)["image"]
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
|
||||
"image"
|
||||
]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
print(img_hq.shape)
|
||||
@ -721,7 +699,5 @@ if __name__ == "__main__":
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0,
|
||||
)
|
||||
img_concat = np.concatenate(
|
||||
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
|
||||
)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
util.imsave(img_concat, str(i) + ".png")
|
||||
|
@ -11,6 +11,7 @@ from torchvision.utils import make_grid
|
||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
|
||||
@ -296,22 +297,14 @@ def single2uint16(img):
|
||||
def uint2tensor4(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return (
|
||||
torch.from_numpy(np.ascontiguousarray(img))
|
||||
.permute(2, 0, 1)
|
||||
.float()
|
||||
.div(255.0)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
|
||||
|
||||
|
||||
# convert uint to 3-dimensional torch tensor
|
||||
def uint2tensor3(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return (
|
||||
torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
|
||||
)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
|
||||
|
||||
|
||||
# convert 2/3/4-dimensional torch tensor to uint
|
||||
@ -334,12 +327,7 @@ def single2tensor3(img):
|
||||
|
||||
# convert single (HxWxC) to 4-dimensional torch tensor
|
||||
def single2tensor4(img):
|
||||
return (
|
||||
torch.from_numpy(np.ascontiguousarray(img))
|
||||
.permute(2, 0, 1)
|
||||
.float()
|
||||
.unsqueeze(0)
|
||||
)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
|
||||
|
||||
|
||||
# convert torch tensor to single
|
||||
@ -362,12 +350,7 @@ def tensor2single3(img):
|
||||
|
||||
|
||||
def single2tensor5(img):
|
||||
return (
|
||||
torch.from_numpy(np.ascontiguousarray(img))
|
||||
.permute(2, 0, 1, 3)
|
||||
.float()
|
||||
.unsqueeze(0)
|
||||
)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
|
||||
|
||||
|
||||
def single32tensor5(img):
|
||||
@ -385,9 +368,7 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
||||
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
||||
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
||||
"""
|
||||
tensor = (
|
||||
tensor.squeeze().float().cpu().clamp_(*min_max)
|
||||
) # squeeze first, then clamp
|
||||
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
|
||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
||||
n_dim = tensor.dim()
|
||||
if n_dim == 4:
|
||||
@ -400,11 +381,7 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
||||
elif n_dim == 2:
|
||||
img_np = tensor.numpy()
|
||||
else:
|
||||
raise TypeError(
|
||||
"Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(
|
||||
n_dim
|
||||
)
|
||||
)
|
||||
raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim))
|
||||
if out_type == np.uint8:
|
||||
img_np = (img_np * 255.0).round()
|
||||
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
|
||||
@ -744,9 +721,7 @@ def ssim(img1, img2):
|
||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
|
||||
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
|
||||
)
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
@ -767,9 +742,7 @@ def cubic(x):
|
||||
) * (((absx > 1) * (absx <= 2)).type_as(absx))
|
||||
|
||||
|
||||
def calculate_weights_indices(
|
||||
in_length, out_length, scale, kernel, kernel_width, antialiasing
|
||||
):
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
||||
if (scale < 1) and (antialiasing):
|
||||
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
||||
kernel_width = kernel_width / scale
|
||||
@ -793,9 +766,9 @@ def calculate_weights_indices(
|
||||
|
||||
# The indices of the input pixels involved in computing the k-th output
|
||||
# pixel are in row k of the indices matrix.
|
||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
|
||||
0, P - 1, P
|
||||
).view(1, P).expand(out_length, P)
|
||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
|
||||
out_length, P
|
||||
)
|
||||
|
||||
# The weights used to compute the k-th output pixel are in row k of the
|
||||
# weights matrix.
|
||||
@ -876,9 +849,7 @@ def imresize(img, scale, antialiasing=True):
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[j, i, :] = (
|
||||
img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
)
|
||||
out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
@ -959,9 +930,7 @@ def imresize_np(img, scale, antialiasing=True):
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[i, :, j] = (
|
||||
img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
|
||||
)
|
||||
out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
|
@ -95,10 +95,7 @@ class ModelGroup(metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<{self.__class__.__name__} object at {id(self):x}: "
|
||||
f"device={self.execution_device} >"
|
||||
)
|
||||
return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
|
||||
|
||||
|
||||
class LazilyLoadedModelGroup(ModelGroup):
|
||||
@ -143,8 +140,7 @@ class LazilyLoadedModelGroup(ModelGroup):
|
||||
self.load(module)
|
||||
if len(forward_input) == 0:
|
||||
warnings.warn(
|
||||
f"Hook for {module.__class__.__name__} got no input. "
|
||||
f"Inputs must be positional, not keywords.",
|
||||
f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
|
||||
stacklevel=3,
|
||||
)
|
||||
return send_to_device(forward_input, self.execution_device)
|
||||
@ -161,9 +157,7 @@ class LazilyLoadedModelGroup(ModelGroup):
|
||||
self.clear_current_model()
|
||||
|
||||
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
assert (
|
||||
self.is_empty()
|
||||
), f"A model is already loaded: {self._current_model_ref()}"
|
||||
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
|
||||
module = module.to(self.execution_device)
|
||||
self.set_current_model(module)
|
||||
return module
|
||||
@ -192,12 +186,8 @@ class LazilyLoadedModelGroup(ModelGroup):
|
||||
|
||||
def device_for(self, model):
|
||||
if model not in self:
|
||||
raise KeyError(
|
||||
f"This does not manage this model {type(model).__name__}", model
|
||||
)
|
||||
return (
|
||||
self.execution_device
|
||||
) # this implementation only dispatches to one device
|
||||
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
|
||||
return self.execution_device # this implementation only dispatches to one device
|
||||
|
||||
def ready(self):
|
||||
pass # always ready to load on-demand
|
||||
@ -256,12 +246,8 @@ class FullyLoadedModelGroup(ModelGroup):
|
||||
|
||||
def device_for(self, model):
|
||||
if model not in self:
|
||||
raise KeyError(
|
||||
"This does not manage this model f{type(model).__name__}", model
|
||||
)
|
||||
return (
|
||||
self.execution_device
|
||||
) # this implementation only dispatches to one device
|
||||
raise KeyError("This does not manage this model f{type(model).__name__}", model)
|
||||
return self.execution_device # this implementation only dispatches to one device
|
||||
|
||||
def __contains__(self, model):
|
||||
return model in self._models
|
||||
|
@ -1 +1 @@
|
||||
from .schedulers import SCHEDULER_MAP
|
||||
from .schedulers import SCHEDULER_MAP
|
||||
|
@ -1,7 +1,19 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSDEScheduler,
|
||||
)
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
@ -21,9 +33,9 @@ SCHEDULER_MAP = dict(
|
||||
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")),
|
||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")),
|
||||
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
|
||||
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)),
|
||||
)
|
||||
|
Reference in New Issue
Block a user