Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
refactor(diffusers): reduce some code duplication amongst the different tasks

commit 821c7df240 (parent d6eef612d7)
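The duplication this commit removes is the drain-the-generator loop that each task (txt2img, img2img, inpaint) repeated around its generator: iterate, forward intermediate states to a callback, and return the last value. The diff routes all of them through the GeneratorToCallbackinator helper instead. Only the helper's fields and docstring appear in the first hunk below; the __call__ body in this sketch is reconstructed from the duplicated loops the diff deletes, so treat it as illustrative rather than the exact source.

# A minimal sketch of the adapter, assuming Python 3.10+ for ParamSpec.
from dataclasses import dataclass
from typing import Callable, Generic, ParamSpec, Type, TypeVar

ParamType = ParamSpec('ParamType')
ReturnType = TypeVar('ReturnType')
CallbackType = TypeVar('CallbackType')


@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
    """Convert a generator to a function with a callback and a return value."""

    generator_method: Callable[ParamType, ReturnType]
    callback_arg_type: Type[CallbackType]

    def __call__(self, *args: ParamType.args,
                 callback: Callable[[CallbackType], None] = None,
                 **kwargs: ParamType.kwargs) -> ReturnType:
        result = None
        # Drain the generator: intermediate values of callback_arg_type go to
        # the callback; the last value yielded becomes the return value.
        for result in self.generator_method(*args, **kwargs):
            if callback is not None and isinstance(result, self.callback_arg_type):
                callback(result)
        if result is None:
            raise AssertionError("why was that an empty generator?")
        return result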
@@ -157,6 +157,8 @@ ParamType = ParamSpec('ParamType')
 
 @dataclass(frozen=True)
 class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
+    """Convert a generator to a function with a callback and a return value."""
+
     generator_method: Callable[ParamType, ReturnType]
     callback_arg_type: Type[CallbackType]
 
@@ -261,111 +263,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         :param run_id:
         :param extra_step_kwargs:
         """
-        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
-        result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
-
-    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
-                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
-                                guidance_scale: float,
-                                *, callback: Callable[[PipelineIntermediateState], None]=None,
-                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
-                                run_id=None,
-                                **extra_step_kwargs) -> PipelineIntermediateState:
-        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
-        f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
-        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                 extra_conditioning_info=extra_conditioning_info,
-                 run_id=run_id,
-                 callback=callback,
-                 **extra_step_kwargs)
-
-    def generate(
-            self,
-            prompt: Union[str, List[str]],
-            *,
-            opposing_prompt: Union[str, List[str]] = None,
-            height: Optional[int] = 512,
-            width: Optional[int] = 512,
-            num_inference_steps: Optional[int] = 50,
-            guidance_scale: Optional[float] = 7.5,
-            generator: Optional[torch.Generator] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            run_id: str = None,
-            **extra_step_kwargs,
-    ):
-        if isinstance(prompt, str):
-            batch_size = 1
-        else:
-            batch_size = len(prompt)
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
-        combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1,
-                                                  do_classifier_free_guidance=do_classifier_free_guidance,
-                                                  negative_prompt=opposing_prompt)
-        text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2)
-        self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels,
-                                       height=height, width=width,
-                                       dtype=self.unet.dtype, device=self._execution_device,
-                                       generator=generator,
-                                       latents=latents)
-
-        yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
-                                                 guidance_scale, run_id=run_id, **extra_step_kwargs)
-
-    def generate_from_embeddings(
-            self,
-            latents: torch.Tensor,
-            text_embeddings: torch.Tensor,
-            unconditioned_embeddings: torch.Tensor,
-            guidance_scale: float,
-            *,
-            run_id: str = None,
-            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
-            timesteps = None,
-            additional_guidance: List[Callable] = None,
-            **extra_step_kwargs):
-        latents = yield from self.generate_latents_from_embeddings(
-            latents, text_embeddings, unconditioned_embeddings,
-            guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
-            timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
+        result_latents = self.latents_from_embeddings(
+            latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
+            extra_conditioning_info=extra_conditioning_info,
+            run_id=run_id, callback=callback, **extra_step_kwargs
+        )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
 
         with torch.inference_mode():
-            image = self.decode_latents(latents)
+            image = self.decode_latents(result_latents)
             output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
-    def generate_latents_from_embeddings(
-            self,
-            latents: torch.Tensor,
-            text_embeddings: torch.Tensor,
-            unconditioned_embeddings: torch.Tensor,
-            guidance_scale: float,
-            *,
-            run_id: str = None,
-            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
-            timesteps = None,
-            additional_guidance: List[Callable] = None,
-            **extra_step_kwargs
-    ):
+    def latents_from_embeddings(
+            self, latents: torch.Tensor, num_inference_steps: int,
+            text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+            guidance_scale: float,
+            *,
+            timesteps = None,
+            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+            additional_guidance: List[Callable] = None,
+            run_id=None,
+            callback: Callable[[PipelineIntermediateState], None] = None,
+            **extra_step_kwargs
+    ) -> torch.Tensor:
+        if timesteps is None:
+            self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+            timesteps = self.scheduler.timesteps
+        infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
+        return infer_latents_from_embeddings(
+            latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
+            extra_conditioning_info=extra_conditioning_info,
+            additional_guidance=additional_guidance,
+            run_id=run_id,
+            callback=callback,
+            **extra_step_kwargs).latents
+
+    def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
+                                         unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
+                                         run_id: str = None,
+                                         extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                         additional_guidance: List[Callable] = None, **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
@@ -375,9 +314,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 step_count=len(self.scheduler.timesteps))
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
-        if timesteps is None:
-            # NOTE: Depends on scheduler being already initialized!
-            timesteps = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
@@ -448,8 +384,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                              run_id=None,
                              noise_func=None,
                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
-        device = self.unet.device
-        latents_dtype = self.unet.dtype
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
@@ -457,13 +391,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
 
         # 6. Prepare latent variables
+        device = self.unet.device
+        latents_dtype = self.unet.dtype
         initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
 
-        result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
-                                                          unconditioned_embeddings, guidance_scale, strength,
-                                                          extra_conditioning_info, noise_func, run_id, callback,
-                                                          **extra_step_kwargs)
-        return result
+        return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
+                                                        unconditioned_embeddings, guidance_scale, strength,
+                                                        extra_conditioning_info, noise_func, run_id, callback,
+                                                        **extra_step_kwargs)
 
     def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
                                             unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
@@ -476,13 +411,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size)
         latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
 
-        f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
-        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                 extra_conditioning_info=extra_conditioning_info,
-                 timesteps=timesteps,
-                 callback=callback,
-                 run_id=run_id, **extra_step_kwargs)
+        result_latents = self.latents_from_embeddings(
+            latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
+            extra_conditioning_info=extra_conditioning_info,
+            timesteps=timesteps,
+            callback=callback,
+            run_id=run_id, **extra_step_kwargs)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(result_latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
+
     def inpaint_from_embeddings(
         self,
         init_image: torch.FloatTensor,
@@ -536,22 +479,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             else:
                 guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
 
-        result = None
-
         try:
-            for result in self.generate_from_embeddings(
-                    latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                    extra_conditioning_info=extra_conditioning_info,
-                    timesteps=timesteps,
-                    run_id=run_id, additional_guidance=guidance, **extra_step_kwargs):
-                if callback is not None and isinstance(result, PipelineIntermediateState):
-                    callback(result)
-            if result is None:
-                raise AssertionError("why was that an empty generator?")
-            return result
+            result_latents = self.latents_from_embeddings(
+                latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                timesteps=timesteps,
+                run_id=run_id, additional_guidance=guidance,
+                callback=callback,
+                **extra_step_kwargs)
         finally:
            self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(result_latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
+
     def non_noised_latents_from_image(self, init_image, *, device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
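After the refactor, every task shares one call shape. A hypothetical usage sketch follows; `pipe` and the prepared tensors are assumptions standing in for the pipeline setup done elsewhere in this file, and the `step` field on PipelineIntermediateState is assumed rather than shown in these hunks.

# Hypothetical caller: progress reporting goes through `callback` instead of
# each task hand-rolling a loop over the generator.
def report_progress(state: PipelineIntermediateState) -> None:
    # Invoked once per denoising step with the intermediate state.
    print(f"denoising step {state.step}")

result_latents = pipe.latents_from_embeddings(
    latents, num_inference_steps, text_embeddings, unconditioned_embeddings,
    guidance_scale, callback=report_progress)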