mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled.
This commit is contained in:
parent
48485fe92f
commit
63d248622c
@ -20,8 +20,11 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_file_storage import ImageType
|
||||
@ -260,9 +263,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
print("type of control input: ", type(self.control))
|
||||
|
||||
# print("type of control input: ", type(self.control))
|
||||
if self.control is None:
|
||||
print("control input is None")
|
||||
control_list = None
|
||||
@ -271,14 +272,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
control_list = None
|
||||
elif isinstance(self.control, ControlField):
|
||||
print("control input is ControlField")
|
||||
# control = [self.control]
|
||||
control_list = [self.control]
|
||||
# elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
|
||||
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
|
||||
print("control input is list[ControlField]")
|
||||
# print("using first controlnet in list")
|
||||
control_list = self.control
|
||||
# control = self.control
|
||||
else:
|
||||
print("input control is unrecognized:", type(self.control))
|
||||
control_list = None
|
||||
@ -286,25 +283,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
#if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
|
||||
if (control_list is None):
|
||||
control_models = None
|
||||
control_weights = None
|
||||
control_images = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
control_data = []
|
||||
control_models = []
|
||||
control_images = []
|
||||
control_weights = []
|
||||
for control_info in control_list:
|
||||
# handle control weights
|
||||
control_weights.append(control_info.control_weight)
|
||||
|
||||
# handle control models
|
||||
# FIXME: change this to dropdown menu
|
||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
|
||||
# handle control images
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
# control_image = prep_control_image(control_info.image)
|
||||
@ -313,6 +303,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = model.prepare_control_image(
|
||||
image=input_image,
|
||||
# do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
@ -324,9 +315,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
)
|
||||
control_images.append(control_image)
|
||||
multi_control = MultiControlNetModel(control_models)
|
||||
model.control_model = multi_control
|
||||
control_item = ControlNetData(model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight)
|
||||
control_data.append(control_item)
|
||||
# multi_control = MultiControlNetModel(control_models)
|
||||
# model.control_model = multi_control
|
||||
# model.control_model = control_models
|
||||
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
@ -335,8 +331,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
control_image=control_images,
|
||||
control_weight=control_weights,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
|
@ -2,10 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import math
|
||||
import secrets
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import einops
|
||||
import PIL.Image
|
||||
@ -212,6 +214,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
|
||||
|
||||
@dataclass
|
||||
class ControlNetData:
|
||||
model: ControlNetModel = Field(default=None)
|
||||
image_tensor: torch.Tensor= Field(default=None)
|
||||
weight: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConditioningData:
|
||||
@ -518,6 +526,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance: List[Callable] = None,
|
||||
run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
@ -539,6 +548,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance=additional_guidance,
|
||||
run_id=run_id,
|
||||
callback=callback,
|
||||
control_data=control_data,
|
||||
**kwargs,
|
||||
)
|
||||
return result.latents, result.attention_map_saver
|
||||
@ -552,6 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise: torch.Tensor,
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
@ -582,7 +593,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(
|
||||
@ -592,6 +603,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
**kwargs,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
@ -633,11 +645,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
@ -645,13 +657,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
|
||||
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
|
||||
control_weight = kwargs.get("control_weight", 1.0) # control_weight default is 1.0
|
||||
# handling case where using multiple control models but only specifying single control_weight
|
||||
# so reshape control_weight to match number of control models
|
||||
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
|
||||
control_weight = [control_weight] * len(self.control_model.nets)
|
||||
# if (self.control_model is not None) and (control_image is not None):
|
||||
if control_data is not None:
|
||||
if conditioning_data.guidance_scale > 1.0:
|
||||
# expand the latents input to control model if doing classifier free guidance
|
||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||
@ -659,16 +666,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latent_control_input = torch.cat([latent_model_input] * 2)
|
||||
else:
|
||||
latent_control_input = latent_model_input
|
||||
# controlnet inference
|
||||
down_block_res_samples, mid_block_res_sample = self.control_model(
|
||||
latent_control_input,
|
||||
timestep,
|
||||
# control_data should be type List[ControlNetData]
|
||||
# this loop covers both ControlNet (1 ControlNetData in list)
|
||||
# and MultiControlNet (multiple ControlNetData in list)
|
||||
for i, control_datum in enumerate(control_data):
|
||||
# print("controlnet", i, "==>", type(control_datum))
|
||||
down_samples, mid_sample = control_datum.model(
|
||||
sample=latent_control_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings]),
|
||||
controlnet_cond=control_image,
|
||||
conditioning_scale=control_weight,
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=control_datum.weight,
|
||||
# cross_attention_kwargs,
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
)
|
||||
if i == 0:
|
||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||
else:
|
||||
# add controlnet outputs together if have multiple controlnets
|
||||
down_block_res_samples = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||
]
|
||||
mid_block_res_sample += mid_sample
|
||||
else:
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user