mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactored ControlNet support to consolidate multiple parameters into a data struct. Also redid how multiple ControlNets are handled.
This commit is contained in:
parent
48485fe92f
commit
63d248622c
@ -20,8 +20,11 @@ from ...backend.util.devices import choose_torch_device, torch_dtype
|
|||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
|
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_file_storage import ImageType
|
from ..services.image_file_storage import ImageType
|
||||||
@ -260,9 +263,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(context, model)
|
conditioning_data = self.get_conditioning_data(context, model)
|
||||||
|
# print("type of control input: ", type(self.control))
|
||||||
print("type of control input: ", type(self.control))
|
|
||||||
|
|
||||||
if self.control is None:
|
if self.control is None:
|
||||||
print("control input is None")
|
print("control input is None")
|
||||||
control_list = None
|
control_list = None
|
||||||
@ -271,14 +272,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
control_list = None
|
control_list = None
|
||||||
elif isinstance(self.control, ControlField):
|
elif isinstance(self.control, ControlField):
|
||||||
print("control input is ControlField")
|
print("control input is ControlField")
|
||||||
# control = [self.control]
|
|
||||||
control_list = [self.control]
|
control_list = [self.control]
|
||||||
# elif isinstance(self.control, list) and len(self.control)>0 and isinstance(self.control[0], ControlField):
|
|
||||||
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
|
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
|
||||||
print("control input is list[ControlField]")
|
print("control input is list[ControlField]")
|
||||||
# print("using first controlnet in list")
|
|
||||||
control_list = self.control
|
control_list = self.control
|
||||||
# control = self.control
|
|
||||||
else:
|
else:
|
||||||
print("input control is unrecognized:", type(self.control))
|
print("input control is unrecognized:", type(self.control))
|
||||||
control_list = None
|
control_list = None
|
||||||
@ -286,25 +283,18 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
#if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
|
#if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
|
||||||
if (control_list is None):
|
if (control_list is None):
|
||||||
control_models = None
|
control_models = None
|
||||||
control_weights = None
|
|
||||||
control_images = None
|
|
||||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||||
else:
|
else:
|
||||||
# FIXME: add checks to skip entry if model or image is None
|
# FIXME: add checks to skip entry if model or image is None
|
||||||
# and if weight is None, populate with default 1.0?
|
# and if weight is None, populate with default 1.0?
|
||||||
|
control_data = []
|
||||||
control_models = []
|
control_models = []
|
||||||
control_images = []
|
|
||||||
control_weights = []
|
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
# handle control weights
|
|
||||||
control_weights.append(control_info.control_weight)
|
|
||||||
|
|
||||||
# handle control models
|
# handle control models
|
||||||
# FIXME: change this to dropdown menu
|
# FIXME: change this to dropdown menu
|
||||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||||
torch_dtype=model.unet.dtype).to(model.device)
|
torch_dtype=model.unet.dtype).to(model.device)
|
||||||
control_models.append(control_model)
|
control_models.append(control_model)
|
||||||
|
|
||||||
# handle control images
|
# handle control images
|
||||||
# loading controlnet image (currently requires pre-processed image)
|
# loading controlnet image (currently requires pre-processed image)
|
||||||
# control_image = prep_control_image(control_info.image)
|
# control_image = prep_control_image(control_info.image)
|
||||||
@ -313,20 +303,26 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
# and do real check for classifier_free_guidance?
|
# and do real check for classifier_free_guidance?
|
||||||
|
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||||
control_image = model.prepare_control_image(
|
control_image = model.prepare_control_image(
|
||||||
image=input_image,
|
image=input_image,
|
||||||
# do_classifier_free_guidance=do_classifier_free_guidance,
|
# do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
width=control_width_resize,
|
width=control_width_resize,
|
||||||
height=control_height_resize,
|
height=control_height_resize,
|
||||||
# batch_size=batch_size * num_images_per_prompt,
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
# num_images_per_prompt=num_images_per_prompt,
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
device=control_model.device,
|
device=control_model.device,
|
||||||
dtype=control_model.dtype,
|
dtype=control_model.dtype,
|
||||||
)
|
)
|
||||||
control_images.append(control_image)
|
control_item = ControlNetData(model=control_model,
|
||||||
multi_control = MultiControlNetModel(control_models)
|
image_tensor=control_image,
|
||||||
model.control_model = multi_control
|
weight=control_info.control_weight)
|
||||||
|
control_data.append(control_item)
|
||||||
|
# multi_control = MultiControlNetModel(control_models)
|
||||||
|
# model.control_model = multi_control
|
||||||
|
# model.control_model = control_models
|
||||||
|
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||||
@ -335,8 +331,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
num_inference_steps=self.steps,
|
num_inference_steps=self.steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
control_image=control_images,
|
control_data=control_data, # list[ControlNetData]
|
||||||
control_weight=control_weights,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
@ -2,10 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
import secrets
|
import secrets
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@ -212,6 +214,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
|||||||
raise AssertionError("why was that an empty generator?")
|
raise AssertionError("why was that an empty generator?")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@dataclass
class ControlNetData:
    """Bundle of everything needed to apply one ControlNet during denoising.

    Defaults are plain values rather than pydantic ``Field(...)`` objects:
    this is a stdlib dataclass, which does not interpret pydantic ``Field``
    as a field specification and would store the ``Field`` object itself as
    the attribute's default value.
    """

    # ControlNet network that is called with the latents and conditioning image.
    model: Optional["ControlNetModel"] = None
    # Pre-processed conditioning image, passed to the model as `controlnet_cond`.
    image_tensor: Optional[torch.Tensor] = None
    # Strength of this ControlNet, passed to the model as `conditioning_scale`.
    weight: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
@ -518,6 +526,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
if self.scheduler.config.get("cpu_only", False):
|
if self.scheduler.config.get("cpu_only", False):
|
||||||
@ -539,6 +548,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
control_data=control_data,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return result.latents, result.attention_map_saver
|
return result.latents, result.attention_map_saver
|
||||||
@ -552,6 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
run_id: str = None,
|
run_id: str = None,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
@ -582,7 +593,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||||
|
# print("timesteps:", timesteps)
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
batched_t.fill_(t)
|
batched_t.fill_(t)
|
||||||
step_output = self.step(
|
step_output = self.step(
|
||||||
@ -592,6 +603,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
|
control_data=control_data,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
@ -633,11 +645,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
|
control_data: List[ControlNetData] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
|
|
||||||
if additional_guidance is None:
|
if additional_guidance is None:
|
||||||
additional_guidance = []
|
additional_guidance = []
|
||||||
|
|
||||||
@ -645,13 +657,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||||
|
|
||||||
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
|
# if (self.control_model is not None) and (control_image is not None):
|
||||||
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
|
if control_data is not None:
|
||||||
control_weight = kwargs.get("control_weight", 1.0) # control_weight default is 1.0
|
|
||||||
# handling case where using multiple control models but only specifying single control_weight
|
|
||||||
# so reshape control_weight to match number of control models
|
|
||||||
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
|
|
||||||
control_weight = [control_weight] * len(self.control_model.nets)
|
|
||||||
if conditioning_data.guidance_scale > 1.0:
|
if conditioning_data.guidance_scale > 1.0:
|
||||||
# expand the latents input to control model if doing classifier free guidance
|
# expand the latents input to control model if doing classifier free guidance
|
||||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
@ -659,16 +666,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latent_control_input = torch.cat([latent_model_input] * 2)
|
latent_control_input = torch.cat([latent_model_input] * 2)
|
||||||
else:
|
else:
|
||||||
latent_control_input = latent_model_input
|
latent_control_input = latent_model_input
|
||||||
# controlnet inference
|
# control_data should be type List[ControlNetData]
|
||||||
down_block_res_samples, mid_block_res_sample = self.control_model(
|
# this loop covers both ControlNet (1 ControlNetData in list)
|
||||||
latent_control_input,
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
timestep,
|
for i, control_datum in enumerate(control_data):
|
||||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
# print("controlnet", i, "==>", type(control_datum))
|
||||||
conditioning_data.text_embeddings]),
|
down_samples, mid_sample = control_datum.model(
|
||||||
controlnet_cond=control_image,
|
sample=latent_control_input,
|
||||||
conditioning_scale=control_weight,
|
timestep=timestep,
|
||||||
return_dict=False,
|
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||||
)
|
conditioning_data.text_embeddings]),
|
||||||
|
controlnet_cond=control_datum.image_tensor,
|
||||||
|
conditioning_scale=control_datum.weight,
|
||||||
|
# cross_attention_kwargs,
|
||||||
|
guess_mode=False,
|
||||||
|
return_dict=False,
|
||||||
|
)
|
||||||
|
if i == 0:
|
||||||
|
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||||
|
else:
|
||||||
|
# add controlnet outputs together if have multiple controlnets
|
||||||
|
down_block_res_samples = [
|
||||||
|
samples_prev + samples_curr
|
||||||
|
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||||
|
]
|
||||||
|
mid_block_res_sample += mid_sample
|
||||||
else:
|
else:
|
||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user