First commit of separate node for IP-Adapter.

And it own dataclasses for passing info.
This commit is contained in:
user1 2023-08-31 23:07:15 -07:00
parent 942ecbbde4
commit 74bfb5e1f9
3 changed files with 171 additions and 75 deletions

View File

@ -19,6 +19,7 @@ from .baseinvocation import (
invocation_output, invocation_output,
) )
CONTROL_ADAPTER_TYPES = Literal["ControlNet", "IP-Adapter", "T2I-Adapter"]
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"] CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[ CONTROLNET_RESIZE_VALUES = Literal[
@ -36,8 +37,15 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel): class ControlField(BaseModel):
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
image: ImageField = Field(description="The control image") image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use") # control_model and ip_adapter_models are both optional
# but must be on the two present
# if control_type == "ControlNet", then mus be control_model
# if control_type == "IP-Adapter", then must be ip_adapter_model
control_model: Optional[ControlNetModelField] = Field(description="The ControlNet model to use")
ip_adapter_model: Optional[str] = Field(description="The IP-Adapter model to use")
image_encoder_model: Optional[str] = Field(description="The clip_image_encoder model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -97,8 +105,12 @@ class ControlNetInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput( return ControlOutput(
control=ControlField( control=ControlField(
control_type="ControlNet",
image=self.image, image=self.image,
control_model=self.control_model, control_model=self.control_model,
# ip_adapter_model is currently optional
# must be either a control_model or ip_adapter_model
# ip_adapter_model=None,
control_weight=self.control_weight, control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
@ -106,3 +118,63 @@ class ControlNetInvocation(BaseInvocation):
resize_mode=self.resize_mode, resize_mode=self.resize_mode,
), ),
) )
IP_ADAPTER_MODELS = Literal[
"models_ip_adapter/models/ip-adapter_sd15.bin",
"models_ip_adapter/models/ip-adapter-plus_sd15.bin",
"models_ip_adapter/models/ip-adapter-plus-face_sd15.bin",
"models_ip_adapter/sdxl_models/ip-adapter_sdxl.bin"
]
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
"models_ip_adapter/models/image_encoder/",
"./models_ip_adapter/models/image_encoder/",
"models_ip_adapter/sdxl_models/image_encoder/"
]
@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes"""
type: Literal["ipadapter"] = "ipadapter"
# Inputs
image: ImageField = InputField(description="The control image")
#control_model: ControlNetModelField = InputField(
# default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
#)
ip_adapter_model: IP_ADAPTER_MODELS = InputField(default="./models_ip_adapter/models/ip-adapter_sd15.bin",
description="The IP-Adapter model")
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
default="./models_ip_adapter/models/image_encoder/",
description="The image encoder model")
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
)
# begin_step_percent: float = InputField(
# default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
# )
# end_step_percent: float = InputField(
# default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
# )
# control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
# resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
control_type="IP-Adapter",
image=self.image,
# control_model is currently optional
# must be either a control_model or ip_adapter_model
# control_model=None,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=self.image_encoder_model,
control_weight=self.control_weight,
# rest are currently ignored
#begin_step_percent=self.begin_step_percent,
#end_step_percent=self.end_step_percent,
#control_mode=self.control_mode,
#resize_mode=self.resize_mode,
),
)

View File

@ -40,6 +40,7 @@ from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ConditioningData,
ControlNetData, ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor, image_resized_to_grid_as_tensor,
) )
@ -216,9 +217,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
default=None, default=None,
description=FieldDescriptions.mask, description=FieldDescriptions.mask,
) )
ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6) # ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6)
ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float, # ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float,
title="IP Adapter Strength", ui_order=7) # title="IP Adapter Strength", ui_order=7)
@validator("cfg_scale") @validator("cfg_scale")
def ge_one(cls, v): def ge_one(cls, v):
@ -340,57 +341,71 @@ class DenoiseLatentsInvocation(BaseInvocation):
else: else:
control_list = None control_list = None
if control_list is None: if control_list is None:
control_data = None controlnet_data = None
ip_adapter_data = None
# from above handling, any control that is not None should now be of type list[ControlField] # from above handling, any control that is not None should now be of type list[ControlField]
else: else:
# FIXME: add checks to skip entry if model or image is None # FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0? # and if weight is None, populate with default 1.0?
control_data = [] controlnet_data = []
control_models = [] ip_adapter_data = []
# control_models = []
for control_info in control_list: for control_info in control_list:
control_model = exit_stack.enter_context( if control_info.control_type == "ControlNet":
context.services.model_manager.get_model( control_model = exit_stack.enter_context(
model_name=control_info.control_model.model_name, context.services.model_manager.get_model(
model_type=ModelType.ControlNet, model_name=control_info.control_model.model_name,
base_model=control_info.control_model.base_model, model_type=ModelType.ControlNet,
context=context, base_model=control_info.control_model.base_model,
context=context,
)
) )
)
control_models.append(control_model) # control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name) input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance? # and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image( control_image = prepare_control_image(
image=input_image, image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize, width=control_width_resize,
height=control_height_resize, height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt, # batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt,
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode, resize_mode=control_info.resize_mode,
) )
control_item = ControlNetData( control_item = ControlNetData(
model=control_model, model=control_model, # model object
image_tensor=control_image, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode, control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(), # any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future # but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode, resize_mode=control_info.resize_mode,
) )
control_data.append(control_item) controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data elif control_info.control_type == "IP-Adapter":
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
control_item = IPAdapterData(
ip_adapter_model=control_info.ip_adapter_model, # name of model (NOT model object)
image_encoder_model=control_info.image_encoder_model, # name of model (NOT model obj)
image=input_image,
weight=control_info.control_weight,
)
ip_adapter_data.append(control_item)
return controlnet_data, ip_adapter_data
# original idea by https://github.com/AmericanPresidentJimmyCarter # original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps # TODO: research more for second order schedulers timesteps
@ -499,14 +514,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
if self.ip_adapter_image is not None: # if self.ip_adapter_image is not None:
print("ip_adapter_image:", self.ip_adapter_image) # print("ip_adapter_image:", self.ip_adapter_image)
unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name) # unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name)
print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image) # print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image)
else: # else:
unwrapped_ip_adapter_image = None # unwrapped_ip_adapter_image = None
control_data = self.prep_control_data( controlnet_data, ip_adapter_data = self.prep_control_data(
model=pipeline, model=pipeline,
context=context, context=context,
control_input=self.control, control_input=self.control,
@ -515,6 +530,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack, exit_stack=exit_stack,
) )
print("controlnet_data:", controlnet_data)
print("ip_adapter_data:", ip_adapter_data)
num_inference_steps, timesteps, init_timestep = self.init_scheduler( num_inference_steps, timesteps, init_timestep = self.init_scheduler(
scheduler, scheduler,
@ -534,9 +551,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents, masked_latents=masked_latents,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData], control_data=controlnet_data, # list[ControlNetData],
ip_adapter_image=unwrapped_ip_adapter_image, ip_adapter_data=ip_adapter_data, # list[IPAdapterData],
ip_adapter_strength=self.ip_adapter_strength, # ip_adapter_image=unwrapped_ip_adapter_image,
# ip_adapter_strength=self.ip_adapter_strength,
callback=step_callback, callback=step_callback,
) )

View File

@ -170,6 +170,15 @@ class ControlNetData:
resize_mode: str = Field(default="just_resize") resize_mode: str = Field(default="just_resize")
@dataclass
class IPAdapterData:
ip_adapter_model: str = Field(default=None)
image_encoder_model: str = Field(default=None)
image: PIL.Image = Field(default=None)
# TODO: change to polymorphic so can do different weights per step (once implemented...)
# weight: Union[float, List[float]] = Field(default=1.0)
weight: float = Field(default=1.0)
@dataclass @dataclass
class ConditioningData: class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo unconditioned_embeddings: BasicConditioningInfo
@ -358,8 +367,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_image: Optional[PIL.Image] = None, ip_adapter_data: IPAdapterData = None,
ip_adapter_strength: float = 1.0,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -411,8 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data, conditioning_data,
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
ip_adapter_image=ip_adapter_image, ip_adapter_data=ip_adapter_data,
ip_adapter_strength=ip_adapter_strength,
callback=callback, callback=callback,
) )
finally: finally:
@ -432,8 +439,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*, *,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_image: Optional[PIL.Image] = None, ip_adapter_data: List[IPAdapterData] = None,
ip_adapter_strength: float = 1.0,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
): ):
@ -447,26 +453,26 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0: if timesteps.shape[0] == 0:
return latents, attention_map_saver return latents, attention_map_saver
print("ip_adapter_image: ", type(ip_adapter_image)) # print("ip_adapter_image: ", type(ip_adapter_image))
if ip_adapter_image is not None: if ip_adapter_data is not None and len(ip_adapter_data) > 0:
ip_adapter_info = ip_adapter_data[0]
ip_adapter_image = ip_adapter_info.image
# initialize IPAdapter # initialize IPAdapter
print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height) print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
clip_image_encoder_path = "ip_adapter_models_sd_1.5/image_encoder/"
ip_adapter_model_path = "ip_adapter_models_sd_1.5/ip-adapter_sd15.bin"
# FIXME: # FIXME:
# WARNING! # WARNING!
# IPAdapter constructor modifies UNet model in-place # IPAdapter constructor modifies UNet model in-place
# Adds additional cross-attention layers to UNet model for image embedding # Adds additional cross-attention layers to UNet model for image embedding
# need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
# and how to undo if ip_adapter_image is removed # and how to undo if ip_adapter_image is removed
# use existing model management context etc? # Should reimplement to use existing model management context etc.
# #
ip_adapter = IPAdapter(self, # IPAdapter first arg is StableDiffusionPipeline ip_adapter = IPAdapter(self, # IPAdapter first arg is StableDiffusionPipeline
clip_image_encoder_path, # hardwiring to manually downloaded dir for first pass ip_adapter_info.image_encoder_model,
ip_adapter_model_path, # hardwiring to manually downloaded loc for first pass ip_adapter_info.ip_adapter_model,
"cuda") # hardwiring CUDA GPU for first pass self.unet.device)
# IP-Adapter ==> add additional cross-attention layers to UNet model here? # IP-Adapter ==> add additional cross-attention layers to UNet model here?
ip_adapter.set_scale(ip_adapter_strength) ip_adapter.set_scale(ip_adapter_info.weight)
print("ip_adapter:", ip_adapter) print("ip_adapter:", ip_adapter)
# get image embedding from CLIP and ImageProjModel # get image embedding from CLIP and ImageProjModel