mirror of https://github.com/invoke-ai/InvokeAI
refactor(diffusers): reduce some code duplication amongst the different tasks

parent d6eef612d7
commit 821c7df240
@@ -157,6 +157,8 @@ ParamType = ParamSpec('ParamType')
 
 @dataclass(frozen=True)
 class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
+    """Convert a generator to a function with a callback and a return value."""
+
     generator_method: Callable[ParamType, ReturnType]
     callback_arg_type: Type[CallbackType]
 
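The hunk above documents the small adapter this refactor leans on: a frozen dataclass that wraps a generator method together with the type of value it should report to a callback. For readers outside the codebase, here is a minimal, self-contained sketch of the pattern — an illustration, not the repository's exact implementation (the real class also types its call arguments with ParamSpec):

from dataclasses import dataclass
from typing import Any, Callable, Generic, Type, TypeVar

ReturnType = TypeVar('ReturnType')
CallbackType = TypeVar('CallbackType')

@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ReturnType, CallbackType]):
    """Convert a generator to a function with a callback and a return value."""

    generator_method: Callable[..., Any]
    callback_arg_type: Type[CallbackType]

    def __call__(self, *args, callback: Callable[[CallbackType], Any] = None, **kwargs) -> ReturnType:
        result = None
        # Drain the wrapped generator; intermediate values of the callback's
        # type are reported to the callback as they are produced.
        for result in self.generator_method(*args, **kwargs):
            if callback is not None and isinstance(result, self.callback_arg_type):
                callback(result)
        if result is None:
            raise AssertionError("why was that an empty generator?")
        # The generator's last yielded value doubles as the return value.
        return result

With this adapter in place, each task method can delegate the hand-rolled `result = None` / `for result in ...` / `raise AssertionError(...)` loop that the next hunk deletes.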
@@ -261,111 +263,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         :param run_id:
         :param extra_step_kwargs:
         """
-        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
-        result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
-
-    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
-                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
-                                guidance_scale: float,
-                                *, callback: Callable[[PipelineIntermediateState], None]=None,
-                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
-                                run_id=None,
-                                **extra_step_kwargs) -> PipelineIntermediateState:
-        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
-        f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
-        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                 extra_conditioning_info=extra_conditioning_info,
-                 run_id=run_id,
-                 callback=callback,
-                 **extra_step_kwargs)
-
-    def generate(
-            self,
-            prompt: Union[str, List[str]],
-            *,
-            opposing_prompt: Union[str, List[str]] = None,
-            height: Optional[int] = 512,
-            width: Optional[int] = 512,
-            num_inference_steps: Optional[int] = 50,
-            guidance_scale: Optional[float] = 7.5,
-            generator: Optional[torch.Generator] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            run_id: str = None,
-            **extra_step_kwargs,
-    ):
-        if isinstance(prompt, str):
-            batch_size = 1
-        else:
-            batch_size = len(prompt)
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
-        combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1,
-                                                  do_classifier_free_guidance=do_classifier_free_guidance,
-                                                  negative_prompt=opposing_prompt)
-        text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2)
-        self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels,
-                                       height=height, width=width,
-                                       dtype=self.unet.dtype, device=self._execution_device,
-                                       generator=generator,
-                                       latents=latents)
-
-        yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
-                                                 guidance_scale, run_id=run_id, **extra_step_kwargs)
-
-    def generate_from_embeddings(
-            self,
-            latents: torch.Tensor,
-            text_embeddings: torch.Tensor,
-            unconditioned_embeddings: torch.Tensor,
-            guidance_scale: float,
-            *,
-            run_id: str = None,
-            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
-            timesteps = None,
-            additional_guidance: List[Callable] = None,
-            **extra_step_kwargs):
-        latents = yield from self.generate_latents_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
-            guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
-            timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
-
+        result_latents = self.latents_from_embeddings(
+            latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
+            extra_conditioning_info=extra_conditioning_info,
+            run_id=run_id, callback=callback, **extra_step_kwargs
+        )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
 
         with torch.inference_mode():
-            image = self.decode_latents(latents)
+            image = self.decode_latents(result_latents)
             output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
-    def generate_latents_from_embeddings(
-            self,
-            latents: torch.Tensor,
-            text_embeddings: torch.Tensor,
-            unconditioned_embeddings: torch.Tensor,
+    def latents_from_embeddings(
+            self, latents: torch.Tensor, num_inference_steps: int,
+            text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
             guidance_scale: float,
             *,
-            run_id: str = None,
-            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
             timesteps = None,
+            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
             additional_guidance: List[Callable] = None,
+            run_id=None,
+            callback: Callable[[PipelineIntermediateState], None]=None,
             **extra_step_kwargs
-    ):
+    ) -> torch.Tensor:
+        if timesteps is None:
+            self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+            timesteps = self.scheduler.timesteps
+        infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
+        return infer_latents_from_embeddings(
+            latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
+            extra_conditioning_info=extra_conditioning_info,
+            additional_guidance=additional_guidance,
+            run_id=run_id,
+            callback=callback,
+            **extra_step_kwargs).latents
+
+    def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
+                                         unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
+                                         run_id: str = None,
+                                         extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                         additional_guidance: List[Callable] = None, **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
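After this hunk, the txt2img path obtains its final latents from latents_from_embeddings, and the same entry point serves the img2img and inpainting hunks below; progress reporting is just a callback. A hedged usage sketch — `pipeline`, `latents`, and the embedding tensors are assumed to already exist, and the `PipelineIntermediateState` fields follow the surrounding file rather than anything shown in this diff:

def report_progress(state: PipelineIntermediateState):
    # run_id and step are assumed fields of PipelineIntermediateState
    print(f"run {state.run_id}: step {state.step}")

result_latents = pipeline.latents_from_embeddings(
    latents, num_inference_steps=50,
    text_embeddings=text_embeddings,
    unconditioned_embeddings=unconditioned_embeddings,
    guidance_scale=7.5,
    callback=report_progress)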
@@ -375,9 +314,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                                               step_count=len(self.scheduler.timesteps))
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
-        if timesteps is None:
-            # NOTE: Depends on scheduler being already initialized!
-            timesteps = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
@@ -448,8 +384,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                               run_id=None,
                               noise_func=None,
                               **extra_step_kwargs) -> StableDiffusionPipelineOutput:
-        device = self.unet.device
-        latents_dtype = self.unet.dtype
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
@@ -457,13 +391,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
 
         # 6. Prepare latent variables
+        device = self.unet.device
+        latents_dtype = self.unet.dtype
         initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
 
-        result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
+        return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
                                                           unconditioned_embeddings, guidance_scale, strength,
                                                           extra_conditioning_info, noise_func, run_id, callback,
                                                           **extra_step_kwargs)
-        return result
 
     def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
                                             unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
@@ -476,13 +411,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size)
         latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
 
-        f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
-        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+        result_latents = self.latents_from_embeddings(
+            latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
+            timesteps=timesteps,
             callback=callback,
             run_id=run_id, **extra_step_kwargs)
 
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(result_latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
     def inpaint_from_embeddings(
         self,
         init_image: torch.FloatTensor,
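Note that img2img keeps its strength-based schedule truncation outside the shared path: only the caller knows how much of the schedule to keep, so it passes timesteps= explicitly. A sketch of that derivation in the usual diffusers style — illustrative only, since this file's exact helper is not shown in the diff:

import torch

def img2img_timesteps(scheduler, num_inference_steps: int, strength: float, batch_size: int):
    # strength in (0, 1]: keep the tail of the schedule; lower strength
    # means fewer denoising steps and more of the init image preserved.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start:]
    # the first retained timestep is the noise level applied to the init latents
    latent_timestep = timesteps[:1].repeat(batch_size)
    return timesteps, latent_timestep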
@@ -536,22 +479,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         else:
             guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
 
-        result = None
-
         try:
-            for result in self.generate_from_embeddings(
-                    latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                    extra_conditioning_info=extra_conditioning_info,
-                    timesteps=timesteps,
-                    run_id=run_id, additional_guidance=guidance, **extra_step_kwargs):
-                if callback is not None and isinstance(result, PipelineIntermediateState):
-                    callback(result)
-            if result is None:
-                raise AssertionError("why was that an empty generator?")
-            return result
+            result_latents = self.latents_from_embeddings(
+                latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                timesteps=timesteps,
+                run_id=run_id, additional_guidance=guidance,
+                callback=callback,
+                **extra_step_kwargs)
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(result_latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
     def non_noised_latents_from_image(self, init_image, *, device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():