refactor(diffusers): reduce some code duplication amongst the different tasks

This commit is contained in:
Kevin Turner 2022-12-07 18:20:56 -08:00
parent d6eef612d7
commit 821c7df240

View File

@ -157,6 +157,8 @@ ParamType = ParamSpec('ParamType')
@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
"""Convert a generator to a function with a callback and a return value."""
generator_method: Callable[ParamType, ReturnType]
callback_arg_type: Type[CallbackType]
@ -261,111 +263,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
:param run_id:
:param extra_step_kwargs:
"""
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
result = None
for result in self.generate_from_embeddings(
latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
run_id=run_id, **extra_step_kwargs):
if callback is not None and isinstance(result, PipelineIntermediateState):
callback(result)
if result is None:
raise AssertionError("why was that an empty generator?")
return result
def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
guidance_scale: float,
*, callback: Callable[[PipelineIntermediateState], None]=None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
run_id=None,
**extra_step_kwargs) -> PipelineIntermediateState:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
run_id=run_id,
callback=callback,
**extra_step_kwargs)
def generate(
self,
prompt: Union[str, List[str]],
*,
opposing_prompt: Union[str, List[str]] = None,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
run_id: str = None,
**extra_step_kwargs,
):
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=opposing_prompt)
text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2)
self.scheduler.set_timesteps(num_inference_steps)
latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels,
height=height, width=width,
dtype=self.unet.dtype, device=self._execution_device,
generator=generator,
latents=latents)
yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
guidance_scale, run_id=run_id, **extra_step_kwargs)
def generate_from_embeddings(
self,
latents: torch.Tensor,
text_embeddings: torch.Tensor,
unconditioned_embeddings: torch.Tensor,
guidance_scale: float,
*,
run_id: str = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
timesteps = None,
additional_guidance: List[Callable] = None,
**extra_step_kwargs):
latents = yield from self.generate_latents_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
result_latents = self.latents_from_embeddings(
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
run_id=run_id, callback=callback, **extra_step_kwargs
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(latents)
image = self.decode_latents(result_latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
yield self.check_for_safety(output, dtype=text_embeddings.dtype)
return self.check_for_safety(output, dtype=text_embeddings.dtype)
def generate_latents_from_embeddings(
self,
latents: torch.Tensor,
text_embeddings: torch.Tensor,
unconditioned_embeddings: torch.Tensor,
def latents_from_embeddings(
self, latents: torch.Tensor, num_inference_steps: int,
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
guidance_scale: float,
*,
run_id: str = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
timesteps = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
additional_guidance: List[Callable] = None,
run_id=None,
callback: Callable[[PipelineIntermediateState], None]=None,
**extra_step_kwargs
):
) -> torch.Tensor:
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
return infer_latents_from_embeddings(
latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
additional_guidance=additional_guidance,
run_id=run_id,
callback=callback,
**extra_step_kwargs).latents
def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
run_id: str = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
additional_guidance: List[Callable] = None, **extra_step_kwargs):
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None:
@ -375,9 +314,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_count=len(self.scheduler.timesteps))
else:
self.invokeai_diffuser.remove_cross_attention_control()
if timesteps is None:
# NOTE: Depends on scheduler being already initialized!
timesteps = self.scheduler.timesteps
# scale the initial noise by the standard deviation required by the scheduler
latents *= self.scheduler.init_noise_sigma
@ -448,8 +384,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
noise_func=None,
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
device = self.unet.device
latents_dtype = self.unet.dtype
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
@ -457,13 +391,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
# 6. Prepare latent variables
device = self.unet.device
latents_dtype = self.unet.dtype
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
unconditioned_embeddings, guidance_scale, strength,
extra_conditioning_info, noise_func, run_id, callback,
**extra_step_kwargs)
return result
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
@ -476,13 +411,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latent_timestep = timesteps[:1].repeat(batch_size)
latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
result_latents = self.latents_from_embeddings(
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
callback=callback,
run_id=run_id, **extra_step_kwargs)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
return self.check_for_safety(output, dtype=text_embeddings.dtype)
def inpaint_from_embeddings(
self,
init_image: torch.FloatTensor,
@ -536,22 +479,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
else:
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
result = None
try:
for result in self.generate_from_embeddings(
latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
run_id=run_id, additional_guidance=guidance, **extra_step_kwargs):
if callback is not None and isinstance(result, PipelineIntermediateState):
callback(result)
if result is None:
raise AssertionError("why was that an empty generator?")
return result
result_latents = self.latents_from_embeddings(
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
run_id=run_id, additional_guidance=guidance,
callback=callback,
**extra_step_kwargs)
finally:
self.invokeai_diffuser.model_forward_callback = self._unet_forward
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
return self.check_for_safety(output, dtype=text_embeddings.dtype)
def non_noised_latents_from_image(self, init_image, *, device, dtype):
init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode():