Mirror of https://github.com/invoke-ai/InvokeAI
First commit of a separate node for IP-Adapter, and its own dataclasses for passing info.
@@ -170,6 +170,15 @@ class ControlNetData:
     resize_mode: str = Field(default="just_resize")


+@dataclass
+class IPAdapterData:
+    ip_adapter_model: str = Field(default=None)
+    image_encoder_model: str = Field(default=None)
+    image: PIL.Image = Field(default=None)
+    # TODO: change to polymorphic so can do different weights per step (once implemented...)
+    # weight: Union[float, List[float]] = Field(default=1.0)
+    weight: float = Field(default=1.0)
+
 @dataclass
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
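The new IPAdapterData dataclass above is essentially a bag of settings that the new IP-Adapter node can fill in and hand to the pipeline. A minimal sketch of constructing one, purely for illustration: the model paths reuse the hard-wired locations removed further down in this diff, and the reference image path is a hypothetical placeholder.

from PIL import Image

# Illustrative only: paths and image are placeholders, not values shipped with InvokeAI.
ip_adapter_data = IPAdapterData(
    ip_adapter_model="ip_adapter_models_sd_1.5/ip-adapter_sd15.bin",
    image_encoder_model="ip_adapter_models_sd_1.5/image_encoder/",
    image=Image.open("reference.png").convert("RGB"),  # hypothetical reference image
    weight=0.8,  # one scale for all steps, per the TODO above
)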
@@ -358,8 +367,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        ip_adapter_image: Optional[PIL.Image] = None,
-        ip_adapter_strength: float = 1.0,
+        ip_adapter_data: IPAdapterData = None,
         mask: Optional[torch.Tensor] = None,
         masked_latents: Optional[torch.Tensor] = None,
         seed: Optional[int] = None,
@@ -411,8 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 conditioning_data,
                 additional_guidance=additional_guidance,
                 control_data=control_data,
-                ip_adapter_image=ip_adapter_image,
-                ip_adapter_strength=ip_adapter_strength,
+                ip_adapter_data=ip_adapter_data,
                 callback=callback,
             )
         finally:
@@ -432,8 +439,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         *,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        ip_adapter_image: Optional[PIL.Image] = None,
-        ip_adapter_strength: float = 1.0,
+        ip_adapter_data: List[IPAdapterData] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ):

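Taken together, the three signature hunks above replace the loose ip_adapter_image and ip_adapter_strength keywords with the single dataclass. Note that the outer entry point is annotated with one IPAdapterData while the inner loop in the next hunk expects List[IPAdapterData] and reads element [0]. A rough sketch of the new call site; the method name latents_from_embeddings and the surrounding arguments are assumptions, not shown in this diff:

# Sketch of the call-site change implied by the signature hunks above.
# `pipeline` is a StableDiffusionGeneratorPipeline; latents, timesteps and
# conditioning_data are assumed to come from setup code this diff does not show.
result = pipeline.latents_from_embeddings(  # method name is an assumption
    latents,
    timesteps,
    conditioning_data,
    ip_adapter_data=ip_adapter_data,  # replaces ip_adapter_image= and ip_adapter_strength=
    callback=callback,
)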
@@ -447,26 +453,26 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver

-        print("ip_adapter_image: ", type(ip_adapter_image))
-        if ip_adapter_image is not None:
+        # print("ip_adapter_image: ", type(ip_adapter_image))
+        if ip_adapter_data is not None and len(ip_adapter_data) > 0:
+            ip_adapter_info = ip_adapter_data[0]
+            ip_adapter_image = ip_adapter_info.image
             # initialize IPAdapter
             print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
-            clip_image_encoder_path = "ip_adapter_models_sd_1.5/image_encoder/"
-            ip_adapter_model_path = "ip_adapter_models_sd_1.5/ip-adapter_sd15.bin"
             # FIXME:
             # WARNING!
             # IPAdapter constructor modifies UNet model in-place
             # Adds additional cross-attention layers to UNet model for image embedding
-            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
-            # and how to undo if ip_adapter_image is removed
-            # use existing model management context etc?
+            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
+            # and how to undo if ip_adapter_image is removed
+            # Should reimplement to use existing model management context etc.
             #
             ip_adapter = IPAdapter(self,  # IPAdapter first arg is StableDiffusionPipeline
-                                   clip_image_encoder_path,  # hardwiring to manually downloaded dir for first pass
-                                   ip_adapter_model_path,  # hardwiring to manually downloaded loc for first pass
-                                   "cuda")  # hardwiring CUDA GPU for first pass
+                                   ip_adapter_info.image_encoder_model,
+                                   ip_adapter_info.ip_adapter_model,
+                                   self.unet.device)
             # IP-Adapter ==> add additional cross-attention layers to UNet model here?
-            ip_adapter.set_scale(ip_adapter_strength)
+            ip_adapter.set_scale(ip_adapter_info.weight)
             print("ip_adapter:", ip_adapter)

             # get image embedding from CLIP and ImageProjModel
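The FIXME block above flags that the IPAdapter constructor patches extra cross-attention layers into the UNet in place, with no guard against patching twice and no way to undo it yet. One possible guard, sketched using only the constructor and set_scale call already visible in this hunk; the _ip_adapter_patched flag and the cached self._ip_adapter attribute are hypothetical, not part of the commit:

# Hypothetical re-patch guard; everything except IPAdapter(...) and set_scale(...)
# is an assumption layered on top of the code in the hunk above.
if ip_adapter_data is not None and len(ip_adapter_data) > 0:
    ip_adapter_info = ip_adapter_data[0]
    if not getattr(self.unet, "_ip_adapter_patched", False):
        # First use: the constructor adds image cross-attention layers to self.unet.
        self._ip_adapter = IPAdapter(
            self,
            ip_adapter_info.image_encoder_model,
            ip_adapter_info.ip_adapter_model,
            self.unet.device,
        )
        self.unet._ip_adapter_patched = True  # remember that the UNet was modified
    # On every call, only the per-run scale needs to be refreshed.
    self._ip_adapter.set_scale(ip_adapter_info.weight)

Undoing the patch when no IP-Adapter is requested, the second half of the FIXME, would still need either the original attention layers saved somewhere or the model management context the comment mentions.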