mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

diffusers: use InvokeAIDiffuserComponent for conditioning

commit b6b1a8d97c · parent 97dd4a2589
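
What this commit does: classifier-free guidance moves out of the pipeline's step() method
and into InvokeAIDiffuserComponent. Callers now pass conditioned and unconditioned
embeddings separately instead of pre-concatenating them with torch.cat([uc, c]), and the
component decides how to batch the two conditionings and combine the two unet predictions.
As a reading aid, here is a minimal sketch of the guidance arithmetic being relocated; the
standalone function is illustrative, but the math matches the step() code deleted below:

    import torch

    def guided_noise_prediction(unet_forward, latents, t,
                                unconditioned_embeddings, text_embeddings,
                                guidance_scale):
        # Run the unet on the unconditioned and conditioned inputs in one batch...
        latent_model_input = torch.cat([latents] * 2)
        embeddings = torch.cat([unconditioned_embeddings, text_embeddings])
        noise_pred = unet_forward(latent_model_input, t, embeddings)
        # ...then push the prediction away from the unconditioned result, toward
        # the conditioned one, scaled by `w` (eq. 2 of the Imagen paper).
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)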
@@ -1,4 +1,5 @@
 import secrets
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
@@ -10,6 +11,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
 
 
@@ -82,6 +84,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             tokenizer=self.tokenizer,
             transformer=self.text_encoder
         )
+        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
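
The new constructor line above hands the component both the unet and a forward callback
(`self._unet_forward`, added near the bottom of the pipeline diff). A hedged sketch of the
resulting inversion of control; the class name and bodies are illustrative, not the
component's actual internals:

    class ComponentSketch:
        def __init__(self, model, model_forward_callback):
            # the unet itself, passed through to cross attention control
            self.model = model
            # the only route the component uses to run a forward pass
            self.model_forward_callback = model_forward_callback

        def do_diffusion_step(self, x, sigma, unconditioned, conditioned,
                              guidance_scale, step_index=None):
            # The component, not the pipeline, now decides how the two
            # conditionings are batched through the callback and how the two
            # predictions are combined (the arithmetic sketched at the top).
            ...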
@@ -128,72 +131,36 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         """
         self.unet.set_use_memory_efficient_attention_xformers(False)
 
-    @torch.no_grad()
-    def __call__(
-            self,
-            prompt: Union[str, List[str]],
-            height: Optional[int] = 512,
-            width: Optional[int] = 512,
-            num_inference_steps: Optional[int] = 50,
-            guidance_scale: Optional[float] = 7.5,
-            generator: Optional[torch.Generator] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            callback: Optional[Callable[[PipelineIntermediateState], None]] = None,
-            **extra_step_kwargs,
-    ):
+    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                              text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                              guidance_scale: float,
+                              *, callback: Callable[[PipelineIntermediateState], None]=None,
+                              extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
+                              run_id=None,
+                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
 
-        Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
-            height (`int`, *optional*, defaults to 512):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to 512):
-                The width in pixels of the generated image.
-            num_inference_steps (`int`, *optional*, defaults to 50):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            guidance_scale (`float`, *optional*, defaults to 7.5):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
-            generator (`torch.Generator`, *optional*):
-                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
-                deterministic.
-            latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will be generated by sampling using the supplied random `generator`.
-
-        Returns:
-            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
-            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
-            When returning a tuple, the first element is a list with the generated images, and the second element is a
-            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
-            (nsfw) content, according to the `safety_checker`.
+        :param latents: Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for
+            image generation. Can be used to tweak the same generation with different prompts.
+        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
+            image at the expense of slower inference.
+        :param text_embeddings:
+        :param guidance_scale: Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+            `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+            Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
+            images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+        :param callback:
+        :param extra_conditioning_info:
+        :param run_id:
+        :param extra_step_kwargs:
         """
-        result = None
-        for result in self.generate(
-                prompt, height=height, width=width, num_inference_steps=num_inference_steps,
-                guidance_scale=guidance_scale, generator=generator, latents=latents,
-                **extra_step_kwargs):
-            if callback is not None:
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
-
-    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
-                              text_embeddings: torch.Tensor, guidance_scale: float,
-                              *, callback: Callable[[PipelineIntermediateState], None]=None, run_id=None,
-                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
-        self.scheduler.set_timesteps(num_inference_steps)
+        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
         result = None
         for result in self.generate_from_embeddings(
-                latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs):
+                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                run_id=run_id, **extra_step_kwargs):
             if callback is not None and isinstance(result, PipelineIntermediateState):
                 callback(result)
         if result is None:
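
Call-site impact of the new image_from_embeddings signature (a sketch; the argument values
mirror the Txt2Img hunk near the end of this page): the caller no longer builds
torch.cat([uc, c]) itself.

    # hypothetical call site; uc and c are the unconditioned and conditioned
    # embeddings the generator already has on hand
    pipeline_output = pipeline.image_from_embeddings(
        latents=x_T,                          # pre-generated noise latents
        num_inference_steps=50,
        text_embeddings=c,                    # conditioned (prompt) embeddings
        unconditioned_embeddings=uc,          # negative-prompt / blank-caption embeddings
        guidance_scale=7.5,
        extra_conditioning_info=extra_conditioning_info,
    )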
@@ -226,24 +193,40 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
-        text_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\
+        text_embeddings, unconditioned_embeddings = self.get_text_embeddings(prompt, opposing_prompt, do_classifier_free_guidance, batch_size)\
             .to(self.unet.device)
         self.scheduler.set_timesteps(num_inference_steps)
         latents = self.prepare_latents(latents, batch_size, height, width, generator, self.unet.dtype)
 
-        yield from self.generate_from_embeddings(latents, text_embeddings, guidance_scale, run_id, **extra_step_kwargs)
+        yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
+                                                 guidance_scale, run_id=run_id, **extra_step_kwargs)
 
-    def generate_from_embeddings(self, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float,
-                                 run_id: str = None, **extra_step_kwargs):
+    def generate_from_embeddings(
+            self,
+            latents: torch.Tensor,
+            text_embeddings: torch.Tensor,
+            unconditioned_embeddings: torch.Tensor,
+            guidance_scale: float,
+            *,
+            run_id: str = None,
+            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+            **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
 
+        batch_size = latents.shape[0]
+        batched_t = torch.full((batch_size,), self.scheduler.timesteps[0],
+                               dtype=self.scheduler.timesteps.dtype, device=self.unet.device)
         # NOTE: Depends on scheduler being already initialized!
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
-            step_output = self.step(t, latents, guidance_scale, text_embeddings, **extra_step_kwargs)
+            batched_t.fill_(t)
+            step_output = self.step(batched_t, latents, guidance_scale,
+                                    text_embeddings, unconditioned_embeddings,
+                                    i, **extra_step_kwargs)
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
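
The batched_t tensor introduced above exists because, as the step() comment below puts it,
invokeai_diffuser has batched timesteps while diffusers schedulers expect a single value;
step() collapses it again with t[0]. A toy illustration with invented values:

    import torch

    batch_size = 4
    t = 981                                    # scalar from scheduler.timesteps
    batched_t = torch.full((batch_size,), t)   # tensor([981, 981, 981, 981])
    batched_t.fill_(t)                         # reused in place on every loop iteration
    timestep = batched_t[0]                    # back to a scalar for scheduler.step()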
@@ -257,23 +240,30 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             yield self.check_for_safety(output)
 
     @torch.inference_mode()
-    def step(self, t, latents: torch.Tensor, guidance_scale, text_embeddings: torch.Tensor, **extra_step_kwargs):
-        do_classifier_free_guidance = guidance_scale > 1.0
+    def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
+             text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+             step_index: int | None = None,
+             **extra_step_kwargs):
+        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
+        timestep = t[0]
 
-        # expand the latents if we are doing classifier free guidance
-        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+        # TODO: should this scaling happen here or inside self._unet_forward?
+        #     i.e. before or after passing it to InvokeAIDiffuserComponent
+        latent_model_input = self.scheduler.scale_model_input(latents, timestep)
 
         # predict the noise residual
-        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-        # perform guidance
-        if do_classifier_free_guidance:
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        noise_pred = self.invokeai_diffuser.do_diffusion_step(
+            latent_model_input, t,
+            unconditioned_embeddings, text_embeddings,
+            guidance_scale,
+            step_index=step_index)
 
         # compute the previous noisy sample x_t -> x_t-1
-        return self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
+        return self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
+
+    def _unet_forward(self, latents, t, text_embeddings):
+        # predict the noise residual
+        return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
 
     @torch.inference_mode()
     def check_for_safety(self, output):
@@ -310,13 +300,10 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             # opposing prompt defaults to blank caption for everything in the batch
             text_anti_input = self._tokenize(opposing_prompt or [""] * batch_size)
             uncond_embeddings = self.text_encoder(text_anti_input.input_ids)[0]
+        else:
+            uncond_embeddings = None
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        # FIXME: assert these two are the same size
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-        return text_embeddings
+        return text_embeddings, uncond_embeddings
 
     @torch.inference_mode()
     def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
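
get_text_embeddings now returns the two halves separately, and the hunk's new else branch
leaves the unconditioned half as None — presumably whenever classifier-free guidance is
disabled. A sketch of the resulting contract (the calls are illustrative; argument order
follows the generate() hunk above):

    # guidance on: both halves come back
    text_embeddings, uncond_embeddings = self.get_text_embeddings(
        prompt, opposing_prompt, True, batch_size)
    assert uncond_embeddings is not None

    # guidance off: no blank-caption encoding runs, unconditioned half is None
    text_embeddings, uncond_embeddings = self.get_text_embeddings(
        prompt, opposing_prompt, False, batch_size)
    assert uncond_embeddings is None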
@@ -325,6 +312,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         """
         return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
 
+    @property
+    def cond_stage_model(self):
+        warnings.warn("legacy compatibility layer", DeprecationWarning)
+        return self.clip_embedder
+
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
         return self.tokenizer(
@@ -5,6 +5,7 @@ import PIL.Image
 import torch
 
 from .base import Generator
+from .diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 
 class Txt2Img(Generator):
@@ -23,7 +24,8 @@ class Txt2Img(Generator):
         self.perlin = perlin
         uc, c, extra_conditioning_info = conditioning
 
-        pipeline = self.model
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
         pipeline.scheduler = sampler
 
         def make_image(x_T) -> PIL.Image.Image:
@@ -31,16 +33,14 @@ class Txt2Img(Generator):
             # if self.free_gpu_mem and self.model.model.device != self.model.device:
             #     self.model.model.to(self.model.device)
 
-            # FIXME: how the embeddings are combined should be internal to the pipeline
-            combined_text_embeddings = torch.cat([uc, c])
-
             pipeline_output = pipeline.image_from_embeddings(
                 latents=x_T,
                 num_inference_steps=steps,
-                text_embeddings=combined_text_embeddings,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
                 guidance_scale=cfg_scale,
                 callback=step_callback,
-                # TODO: extra_conditioning_info = extra_conditioning_info,
+                extra_conditioning_info=extra_conditioning_info,
                 # TODO: eta = ddim_eta,
                 # TODO: threshold = threshold,
             )
@@ -35,6 +35,7 @@ class InvokeAIDiffuserComponent:
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
         """
+        self.conditioning = None
         self.model = model
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None