Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled.

This commit is contained in:
user1
2023-05-12 01:43:47 -07:00
committed by Kent Keirsey
parent 48485fe92f
commit 63d248622c
2 changed files with 65 additions and 48 deletions

View File

@ -2,10 +2,12 @@ from __future__ import annotations
import dataclasses
import inspect
import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
import einops
import PIL.Image
@ -212,6 +214,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: float = Field(default=1.0)
@dataclass(frozen=True)
class ConditioningData:
@ -518,6 +526,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
@ -539,6 +548,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance,
run_id=run_id,
callback=callback,
control_data=control_data,
**kwargs,
)
return result.latents, result.attention_map_saver
@ -552,6 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
run_id: str = None,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
self._adjust_memory_efficient_attention(latents)
@ -582,7 +593,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(
@ -592,6 +603,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
)
latents = step_output.prev_sample
@ -633,11 +645,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
if additional_guidance is None:
additional_guidance = []
@ -645,13 +657,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
control_weight = kwargs.get("control_weight", 1.0) # control_weight default is 1.0
# handling case where using multiple control models but only specifying single control_weight
# so reshape control_weight to match number of control models
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
control_weight = [control_weight] * len(self.control_model.nets)
# if (self.control_model is not None) and (control_image is not None):
if control_data is not None:
if conditioning_data.guidance_scale > 1.0:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
@ -659,16 +666,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# controlnet inference
down_block_res_samples, mid_block_res_sample = self.control_model(
latent_control_input,
timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_image,
conditioning_scale=control_weight,
return_dict=False,
)
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (1 ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
conditioning_scale=control_datum.weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
)
if i == 0:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
else:
down_block_res_samples, mid_block_res_sample = None, None