Comment unused IPAdapter generate(...) methods.

This commit is contained in:
Ryan Dick 2023-09-08 13:07:57 -04:00
parent b2d5b53b5f
commit d669f0855d

View File

@ -1,8 +1,6 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed # and modified as needed
from typing import List
import torch import torch
# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor # FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
@ -120,134 +118,135 @@ class IPAdapter:
if isinstance(attn_processor, IPAttnProcessor): if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale attn_processor.scale = scale
# IPAdapter.generate() method is not used for InvokeAI # IPAdapter.generate() method is not used for InvokeAI. Left here for reference:
# left here for reference # def generate(
def generate( # self,
self, # pil_image,
pil_image, # prompt=None,
prompt=None, # negative_prompt=None,
negative_prompt=None, # scale=1.0,
scale=1.0, # num_samples=4,
num_samples=4, # seed=-1,
seed=-1, # guidance_scale=7.5,
guidance_scale=7.5, # num_inference_steps=30,
num_inference_steps=30, # **kwargs,
**kwargs, # ):
): # self.set_scale(scale)
self.set_scale(scale)
if isinstance(pil_image, Image.Image): # if isinstance(pil_image, Image.Image):
num_prompts = 1 # num_prompts = 1
else: # else:
num_prompts = len(pil_image) # num_prompts = len(pil_image)
if prompt is None: # if prompt is None:
prompt = "best quality, high quality" # prompt = "best quality, high quality"
if negative_prompt is None: # if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" # negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List): # if not isinstance(prompt, List):
prompt = [prompt] * num_prompts # prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List): # if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts # negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) # image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape # bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) # image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) # image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) # uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) # uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode(): # with torch.inference_mode():
prompt_embeds = self.pipe._encode_prompt( # prompt_embeds = self.pipe._encode_prompt(
prompt, # prompt,
device=self.device, # device=self.device,
num_images_per_prompt=num_samples, # num_images_per_prompt=num_samples,
do_classifier_free_guidance=True, # do_classifier_free_guidance=True,
negative_prompt=negative_prompt, # negative_prompt=negative_prompt,
) # )
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2) # negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) # prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) # negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None # generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
images = self.pipe( # images = self.pipe(
prompt_embeds=prompt_embeds, # prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, # negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale, # guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps, # num_inference_steps=num_inference_steps,
generator=generator, # generator=generator,
**kwargs, # **kwargs,
).images # ).images
return images # return images
class IPAdapterXL(IPAdapter): class IPAdapterXL(IPAdapter):
"""SDXL""" """SDXL"""
def generate( pass
self, # IPAdapterXL.generate() method is not used for InvokeAI. Left here for reference:
pil_image, # def generate(
prompt=None, # self,
negative_prompt=None, # pil_image,
scale=1.0, # prompt=None,
num_samples=4, # negative_prompt=None,
seed=-1, # scale=1.0,
num_inference_steps=30, # num_samples=4,
**kwargs, # seed=-1,
): # num_inference_steps=30,
self.set_scale(scale) # **kwargs,
# ):
# self.set_scale(scale)
if isinstance(pil_image, Image.Image): # if isinstance(pil_image, Image.Image):
num_prompts = 1 # num_prompts = 1
else: # else:
num_prompts = len(pil_image) # num_prompts = len(pil_image)
if prompt is None: # if prompt is None:
prompt = "best quality, high quality" # prompt = "best quality, high quality"
if negative_prompt is None: # if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" # negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List): # if not isinstance(prompt, List):
prompt = [prompt] * num_prompts # prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List): # if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts # negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) # image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape # bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) # image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) # image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) # uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) # uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode(): # with torch.inference_mode():
( # (
prompt_embeds, # prompt_embeds,
negative_prompt_embeds, # negative_prompt_embeds,
pooled_prompt_embeds, # pooled_prompt_embeds,
negative_pooled_prompt_embeds, # negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt( # ) = self.pipe.encode_prompt(
prompt, # prompt,
num_images_per_prompt=num_samples, # num_images_per_prompt=num_samples,
do_classifier_free_guidance=True, # do_classifier_free_guidance=True,
negative_prompt=negative_prompt, # negative_prompt=negative_prompt,
) # )
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) # prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) # negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None # generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
images = self.pipe( # images = self.pipe(
prompt_embeds=prompt_embeds, # prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, # negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds, # pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps, # num_inference_steps=num_inference_steps,
generator=generator, # generator=generator,
**kwargs, # **kwargs,
).images # ).images
return images # return images
class IPAdapterPlus(IPAdapter): class IPAdapterPlus(IPAdapter):