cleanup: merge conflicts

This commit is contained in:
blessedcoolant
2023-09-05 11:37:12 +12:00
parent 6bb378a101
commit 07381e5a26
8 changed files with 86 additions and 69 deletions

View File

@ -179,6 +179,7 @@ class IPAdapterData:
# weight: Union[float, List[float]] = Field(default=1.0)
weight: float = Field(default=1.0)
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
@ -442,7 +443,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
ip_adapter_data: List[IPAdapterData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
additional_guidance = []
@ -469,30 +469,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
#
if "sdxl" in ip_adapter_info.ip_adapter_model:
print("using IPAdapterXL")
ip_adapter = IPAdapterXL(self,
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
self.unet.device)
ip_adapter = IPAdapterXL(
self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device
)
elif "plus" in ip_adapter_info.ip_adapter_model:
print("using IPAdapterPlus")
ip_adapter = IPAdapterPlus(self, # IPAdapterPlus first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
self.unet.device,
num_tokens=16)
ip_adapter = IPAdapterPlus(
self, # IPAdapterPlus first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
self.unet.device,
num_tokens=16,
)
else:
print("using IPAdapter")
ip_adapter = IPAdapter(self, # IPAdapter first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
self.unet.device)
ip_adapter = IPAdapter(
self, # IPAdapter first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
self.unet.device,
)
# IP-Adapter ==> add additional cross-attention layers to UNet model here?
ip_adapter.set_scale(ip_adapter_info.weight)
print("ip_adapter:", ip_adapter)
# get image embedding from CLIP and ImageProjModel
print("getting image embeddings from IP-Adapter...")
num_samples = 1 # hardwiring for first pass
num_samples = 1 # hardwiring for first pass
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
print("image cond embeds shape:", image_prompt_embeds.shape)
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)