Removed diffusers_pipeline prepare_control_image() -- replaced with controlnet_utils.prepare_control_image()

Added resize_mode to ControlNetData class.
This commit is contained in:
user1 2023-07-19 19:21:17 -07:00
parent 6affe42310
commit e918168f7a

View File

@ -219,6 +219,7 @@ class ControlNetData:
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced") control_mode: str = Field(default="balanced")
resize_mode: str = Field(default="just_resize")
@dataclass @dataclass
@ -653,7 +654,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if cfg_injection: if cfg_injection:
# Inferred ControlNet only for the conditional batch. # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches, # To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged. # prepend zeros for unconditional batch
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
@ -954,53 +955,3 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image( debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
) )
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
@staticmethod
def prepare_control_image(
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
batch_size=1,
num_images_per_prompt=1,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2)
return image