Initial (barely) working version of IP-Adapter model management.

This commit is contained in:
Ryan Dick
2023-09-12 19:09:10 -04:00
parent 0d823901ef
commit 3ee9a21647
8 changed files with 182 additions and 85 deletions

View File

@ -171,8 +171,7 @@ class ControlNetData:
@dataclass
class IPAdapterData:
ip_adapter_model: str = Field(default=None)
image_encoder_model: str = Field(default=None)
ip_adapter_model: IPAdapter = Field(default=None)
image: PIL.Image = Field(default=None)
# TODO: change to polymorphic so can do different weights per step (once implemented...)
# weight: Union[float, List[float]] = Field(default=1.0)
@ -417,27 +416,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return latents, attention_map_saver
if ip_adapter_data is not None:
# Initialize IPAdapter
# TODO(ryand): Refactor to use model management for the IP-Adapter.
if "plus" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterPlus(
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
num_tokens=16,
)
else:
ip_adapter = IPAdapter(
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
)
ip_adapter.set_scale(ip_adapter_data.weight)
if not ip_adapter_data.ip_adapter_model.is_initialized():
# TODO(ryan): Do we need to initialize every time? How long does initialize take?
ip_adapter_data.ip_adapter_model.initialize(self.unet)
ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_data.ip_adapter_model.get_image_embeds(
ip_adapter_data.image
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
@ -451,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
attn_ctx = ip_adapter.apply_ip_adapter_attention()
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention()
else:
attn_ctx = nullcontext()