mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactored ControNet support to consolidate multiple parameters into data struct. Also redid how multiple controlnets are handled.
This commit is contained in:
@ -2,10 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import math
|
||||
import secrets
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import einops
|
||||
import PIL.Image
|
||||
@ -212,6 +214,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
|
||||
|
||||
@dataclass
|
||||
class ControlNetData:
|
||||
model: ControlNetModel = Field(default=None)
|
||||
image_tensor: torch.Tensor= Field(default=None)
|
||||
weight: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConditioningData:
|
||||
@ -518,6 +526,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance: List[Callable] = None,
|
||||
run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
@ -539,6 +548,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance=additional_guidance,
|
||||
run_id=run_id,
|
||||
callback=callback,
|
||||
control_data=control_data,
|
||||
**kwargs,
|
||||
)
|
||||
return result.latents, result.attention_map_saver
|
||||
@ -552,6 +562,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise: torch.Tensor,
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
@ -582,7 +593,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(
|
||||
@ -592,6 +603,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
**kwargs,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
@ -633,11 +645,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
@ -645,13 +657,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
|
||||
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
|
||||
control_weight = kwargs.get("control_weight", 1.0) # control_weight default is 1.0
|
||||
# handling case where using multiple control models but only specifying single control_weight
|
||||
# so reshape control_weight to match number of control models
|
||||
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
|
||||
control_weight = [control_weight] * len(self.control_model.nets)
|
||||
# if (self.control_model is not None) and (control_image is not None):
|
||||
if control_data is not None:
|
||||
if conditioning_data.guidance_scale > 1.0:
|
||||
# expand the latents input to control model if doing classifier free guidance
|
||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||
@ -659,16 +666,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latent_control_input = torch.cat([latent_model_input] * 2)
|
||||
else:
|
||||
latent_control_input = latent_model_input
|
||||
# controlnet inference
|
||||
down_block_res_samples, mid_block_res_sample = self.control_model(
|
||||
latent_control_input,
|
||||
timestep,
|
||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings]),
|
||||
controlnet_cond=control_image,
|
||||
conditioning_scale=control_weight,
|
||||
return_dict=False,
|
||||
)
|
||||
# control_data should be type List[ControlNetData]
|
||||
# this loop covers both ControlNet (1 ControlNetData in list)
|
||||
# and MultiControlNet (multiple ControlNetData in list)
|
||||
for i, control_datum in enumerate(control_data):
|
||||
# print("controlnet", i, "==>", type(control_datum))
|
||||
down_samples, mid_sample = control_datum.model(
|
||||
sample=latent_control_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings]),
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=control_datum.weight,
|
||||
# cross_attention_kwargs,
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
)
|
||||
if i == 0:
|
||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||
else:
|
||||
# add controlnet outputs together if have multiple controlnets
|
||||
down_block_res_samples = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||
]
|
||||
mid_block_res_sample += mid_sample
|
||||
else:
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
|
Reference in New Issue
Block a user