WIP - Accept a list of IPAdapterFields in DenoiseLatents.

Author: Ryan Dick, 2023-09-21 17:46:05 -04:00 (committed by Kent Keirsey)
Parent: 166ff9d301
Commit: 78828b6b9c
3 changed files with 48 additions and 36 deletions


@@ -226,7 +226,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         input=Input.Connection,
         ui_order=5,
     )
-    ip_adapter: Optional[IPAdapterField] = InputField(
+    ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
         description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
     )
     t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
@@ -410,34 +410,43 @@ class DenoiseLatentsInvocation(BaseInvocation):
     def prep_ip_adapter_data(
         self,
         context: InvocationContext,
-        ip_adapter: Optional[IPAdapterField],
+        ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
         conditioning_data: ConditioningData,
-        unet: UNet2DConditionModel,
         exit_stack: ExitStack,
-    ) -> Optional[IPAdapterData]:
+    ) -> Optional[list[IPAdapterData]]:
         """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
         to the `conditioning_data` (in-place).
         """
         if ip_adapter is None:
             return None
 
-        image_encoder_model_info = context.services.model_manager.get_model(
-            model_name=ip_adapter.image_encoder_model.model_name,
-            model_type=ModelType.CLIPVision,
-            base_model=ip_adapter.image_encoder_model.base_model,
-            context=context,
-        )
-        ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-            context.services.model_manager.get_model(
-                model_name=ip_adapter.ip_adapter_model.model_name,
-                model_type=ModelType.IPAdapter,
-                base_model=ip_adapter.ip_adapter_model.base_model,
-                context=context,
-            )
-        )
-        input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
+        # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+        if not isinstance(ip_adapter, list):
+            ip_adapter = [ip_adapter]
+
+        if len(ip_adapter) == 0:
+            return None
+
+        ip_adapter_data_list = []
+        conditioning_data.ip_adapter_conditioning = []
+        for single_ip_adapter in ip_adapter:
+            ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
+                context.services.model_manager.get_model(
+                    model_name=single_ip_adapter.ip_adapter_model.model_name,
+                    model_type=ModelType.IPAdapter,
+                    base_model=single_ip_adapter.ip_adapter_model.base_model,
+                    context=context,
+                )
+            )
+
+            image_encoder_model_info = context.services.model_manager.get_model(
+                model_name=single_ip_adapter.image_encoder_model.model_name,
+                model_type=ModelType.CLIPVision,
+                base_model=single_ip_adapter.image_encoder_model.base_model,
+                context=context,
+            )
+
+            input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
 
             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@@ -446,16 +455,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
                     input_image, image_encoder_model
                 )
-            conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
-                image_prompt_embeds, uncond_image_prompt_embeds
-            )
+                conditioning_data.ip_adapter_conditioning.append(
+                    IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
+                )
 
-        return IPAdapterData(
-            ip_adapter_model=ip_adapter_model,
-            weight=ip_adapter.weight,
-            begin_step_percent=ip_adapter.begin_step_percent,
-            end_step_percent=ip_adapter.end_step_percent,
-        )
+            ip_adapter_data_list.append(
+                IPAdapterData(
+                    ip_adapter_model=ip_adapter_model,
+                    weight=ip_adapter.weight,
+                    begin_step_percent=ip_adapter.begin_step_percent,
+                    end_step_percent=ip_adapter.end_step_percent,
+                )
+            )
+
+        return ip_adapter_data_list
 
     def run_t2i_adapters(
         self,
@@ -677,7 +690,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 context=context,
                 ip_adapter=self.ip_adapter,
                 conditioning_data=conditioning_data,
-                unet=unet,
                 exit_stack=exit_stack,
             )
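
The substantive change in this file is that prep_ip_adapter_data now accepts either a single IPAdapterField or a list, normalizes its input to a list, and returns a list of IPAdapterData. A minimal standalone sketch of that normalization pattern follows; the function name and the str payload are illustrative only, not taken from the codebase:

    from typing import Optional, Union

    def normalize_to_list(value: Optional[Union[str, list[str]]]) -> Optional[list[str]]:
        """Accept one item, a list of items, or None; return a non-empty list or None."""
        if value is None:
            return None
        # Wrap a bare item so downstream code can always iterate.
        if not isinstance(value, list):
            value = [value]
        # Treat an empty list the same as "not provided".
        if len(value) == 0:
            return None
        return value

    assert normalize_to_list("a") == ["a"]
    assert normalize_to_list(["a", "b"]) == ["a", "b"]
    assert normalize_to_list([]) is None
    assert normalize_to_list(None) is None

This keeps the node's interface backward compatible: graphs that connect a single IP-Adapter keep working, while new graphs can connect several.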


@@ -336,7 +336,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        ip_adapter_data: Optional[IPAdapterData] = None,
+        ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
         mask: Optional[torch.Tensor] = None,
         masked_latents: Optional[torch.Tensor] = None,
@@ -410,7 +410,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         *,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        ip_adapter_data: Optional[IPAdapterData] = None,
+        ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ):
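
Both pipeline entry points now take Optional[list[IPAdapterData]] in place of a single value, so existing callers must wrap a lone IPAdapterData in a list. One way a consumer could use the per-adapter begin_step_percent/end_step_percent fields is to gate each adapter at every denoising step. The sketch below is an assumption about how such gating might look, not the pipeline's actual logic:

    # Hypothetical helper: select the adapters active at a given step.
    def active_ip_adapters(ip_adapter_data, step_index: int, total_steps: int) -> list:
        if not ip_adapter_data:
            return []
        # Fraction of the schedule completed, in [0, 1].
        fraction = step_index / max(total_steps - 1, 1)
        return [
            d
            for d in ip_adapter_data
            if d.begin_step_percent <= fraction <= d.end_step_percent
        ]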


@@ -81,7 +81,7 @@ class ConditioningData:
     """
     postprocessing_settings: Optional[PostprocessingSettings] = None
-    ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
+    ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
 
     @property
     def dtype(self):
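
Correspondingly, ip_adapter_conditioning now holds one IPAdapterConditioningInfo per adapter, appended by prep_ip_adapter_data in the same order as the returned IPAdapterData list. A self-contained sketch of that invariant; the dataclass mirrors only the two embedding values visible in this diff, and its field names and tensor shapes are assumptions:

    from dataclasses import dataclass
    import torch

    @dataclass
    class IPAdapterConditioningInfo:
        image_prompt_embeds: torch.Tensor
        uncond_image_prompt_embeds: torch.Tensor

    # One entry per IP-Adapter, in adapter order; None would mean no IP-Adapter.
    ip_adapter_conditioning: list[IPAdapterConditioningInfo] = []
    for _ in range(2):  # e.g. two adapters connected to the node
        ip_adapter_conditioning.append(
            IPAdapterConditioningInfo(torch.zeros(1, 4, 768), torch.zeros(1, 4, 768))
        )
    assert len(ip_adapter_conditioning) == 2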