Tidy variable management and dtype handling in FluxTextToImageInvocation.

Ryan Dick 2024-08-28 15:03:08 +00:00
parent 5e8cf9fb6a
commit 4e4b6c6dbc
3 changed files with 24 additions and 26 deletions

View File

@@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
@@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16
 
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,
@@ -88,13 +90,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )
 
-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)
 
         is_schnell = "schnell" in transformer_info.config.config_path
         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )
@@ -135,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             x = denoise(
                 model=transformer,
-                img=img,
+                img=x,
                 img_ids=img_ids,
                 txt=t5_embeddings,
                 txt_ids=txt_ids,
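
Taken together, the hunks above move conditioning loading into _run_diffusion so the embeddings can be cast to the inference dtype the moment they are loaded, and collapse the img/x naming so the latent flows through a single variable. A minimal standalone sketch of the load-then-cast pattern (tensor shapes here are illustrative, not taken from the commit):

    import torch

    inference_dtype = torch.bfloat16

    # Stand-ins for the loaded FLUX conditioning tensors (illustrative shapes).
    clip_embeds = torch.randn(1, 768)
    t5_embeds = torch.randn(1, 512, 4096)

    # Cast once at load time; everything created downstream (noise, position
    # ids, guidance) is built with this dtype instead of being re-cast later.
    clip_embeds = clip_embeds.to(dtype=inference_dtype)
    t5_embeds = t5_embeds.to(dtype=inference_dtype)

    assert clip_embeds.dtype == t5_embeds.dtype == inference_dtype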

View File

@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
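
With the defensive casts deleted, denoise now assumes img, img_ids, txt, txt_ids, and vec already match the model's dtype; the contract moves to the caller. If an explicit guard were ever wanted in place of the old silent casts, a sketch might look like this (check_dtypes is hypothetical, not part of this commit):

    import torch

    def check_dtypes(model_dtype: torch.dtype, **tensors: torch.Tensor) -> None:
        # Fail fast instead of silently casting, so caller bugs surface early.
        for name, t in tensors.items():
            if t.dtype != model_dtype:
                raise TypeError(f"'{name}' has dtype {t.dtype}, expected {model_dtype}")

    # Usage at the top of denoise() would be something like:
    # check_dtypes(model_dtype, img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, vec=vec)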
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
     img = repeat(img, "1 ... -> bs ...", bs=bs)
 
     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
     return img, img_ids
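
This hunk closes the loop on the removed casts: the patch position ids were previously created with the float32 defaults of torch.zeros/torch.arange while the image tokens were bfloat16. A self-contained sketch of the corrected construction (the dummy latent shape is illustrative):

    import torch

    latent = torch.zeros(1, 16, 64, 64, dtype=torch.bfloat16)  # dummy (b, c, h, w) latent
    h, w = latent.shape[-2], latent.shape[-1]

    # The ids inherit the latent's dtype, so nothing gets silently promoted
    # to float32 when they later meet bfloat16 activations.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=latent.device, dtype=latent.dtype)
    img_ids[..., 1] += torch.arange(h // 2, device=latent.device, dtype=latent.dtype)[:, None]
    img_ids[..., 2] += torch.arange(w // 2, device=latent.device, dtype=latent.dtype)[None, :]

    assert img_ids.dtype == latent.dtype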

View File

@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
     clip_embeds: torch.Tensor
     t5_embeds: torch.Tensor
 
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
+        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
+        return self
+
 
 @dataclass
 class ConditioningFieldData:
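
A usage sketch of the new helper (shapes illustrative). One design nuance worth noting: unlike torch.Tensor.to(), which returns a new tensor, this method mutates the dataclass in place and returns self for chaining, so callers should not expect the original-precision embeddings to survive the call:

    import torch
    from dataclasses import dataclass

    @dataclass
    class FLUXConditioningInfo:  # mirrors the patched class above
        clip_embeds: torch.Tensor
        t5_embeds: torch.Tensor

        def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
            self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
            self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
            return self

    cond = FLUXConditioningInfo(torch.randn(1, 768), torch.randn(1, 512, 4096))
    cond = cond.to(dtype=torch.bfloat16)  # casts both tensors in place, returns self
    assert cond.clip_embeds.dtype == cond.t5_embeds.dtype == torch.bfloat16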