Explain the Magic

Kent Keirsey
2025-06-26 23:57:04 -04:00
committed by psychedelicious
parent 7549c1250d
commit ca1df60e54
2 changed files with 36 additions and 17 deletions

View File

@ -384,7 +384,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
            dtype=inference_dtype,
        )

-       # Instantiate our new extension if the conditioning is provided
        kontext_extension = None
        if self.kontext_conditioning is not None:
            # We need a VAE to encode the reference image. We can reuse the
@ -400,7 +399,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                dtype=inference_dtype,
            )

-       # THE CRITICAL INTEGRATION POINT
        final_img, final_img_ids = x, img_ids
        original_seq_len = x.shape[1]  # Store the original sequence length
        if kontext_extension is not None:
@ -426,7 +424,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
            img_cond=img_cond,
        )

-       # Extract only the main image tokens if kontext was applied
        if kontext_extension is not None:
            x = x[:, :original_seq_len, :]  # Keep only the first original_seq_len tokens
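Together, these hunks bracket a concatenate/run/slice pattern: the reference-image tokens are appended to the main sequence before denoising and dropped again before decoding. A minimal sketch of that pattern, with illustrative tensor names and shapes (not the invocation's actual variables):

```python
import torch

# Illustrative sizes: a 2x2-packed 64x64 latent gives 1024 tokens of 64 features.
batch_size, seq_len, features = 1, 1024, 64
x = torch.randn(batch_size, seq_len, features)               # main image tokens
kontext_tokens = torch.randn(batch_size, seq_len, features)  # reference image tokens

original_seq_len = x.shape[1]  # remember where the main image ends

# Before denoising: append the reference tokens so the transformer
# attends across both images as one sequence.
x = torch.cat([x, kontext_tokens], dim=1)  # [1, 2048, 64]

# ... denoising loop runs on the combined sequence ...

# After denoising: keep only the main image tokens for decoding.
x = x[:, :original_seq_len, :]  # back to [1, 1024, 64]
```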

View File

@ -11,30 +11,48 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t

 def generate_img_ids_with_offset(
-    h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
+    latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
 ) -> torch.Tensor:
     """Generate tensor of image position ids with an optional offset.

     Args:
-        h (int): Height of image in latent space.
-        w (int): Width of image in latent space.
-        batch_size (int): Batch size.
-        device (torch.device): Device.
-        dtype (torch.dtype): dtype.
+        latent_height (int): Height of image in latent space (after packing, this becomes h//2).
+        latent_width (int): Width of image in latent space (after packing, this becomes w//2).
+        batch_size (int): Number of images in the batch.
+        device (torch.device): Device to create tensors on.
+        dtype (torch.dtype): Data type for the tensors.
         idx_offset (int): Offset to add to the first dimension of the image ids.

     Returns:
-        torch.Tensor: Image position ids.
+        torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
     """

     if device.type == "mps":
         orig_dtype = dtype
         dtype = torch.float16

-    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
-    img_ids[..., 0] = idx_offset  # Set the offset for the first dimension
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
+    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
+    packed_height = latent_height // 2
+    packed_width = latent_width // 2
+
+    # Create base tensor for position IDs with shape [packed_height, packed_width, 3]
+    # The 3 channels represent: [batch_offset, y_position, x_position]
+    img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
+
+    # Set the batch offset for all positions
+    img_ids[..., 0] = idx_offset
+
+    # Create y-coordinate indices (vertical positions)
+    y_indices = torch.arange(packed_height, device=device, dtype=dtype)
+    # Broadcast y_indices to match the spatial dimensions [packed_height, 1]
+    img_ids[..., 1] = y_indices[:, None]
+
+    # Create x-coordinate indices (horizontal positions)
+    x_indices = torch.arange(packed_width, device=device, dtype=dtype)
+    # Broadcast x_indices to match the spatial dimensions [1, packed_width]
+    img_ids[..., 2] = x_indices[None, :]
+
+    # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)

     if device.type == "mps":
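To make the new ID layout concrete, here is a small self-contained check that mirrors the function body above for a 4x4 latent, which packs to a 2x2 token grid (the printed values follow directly from the code):

```python
import torch
from einops import repeat

# Mirror of the function body above for a tiny example.
latent_height, latent_width, batch_size, idx_offset = 4, 4, 1, 1
packed_height, packed_width = latent_height // 2, latent_width // 2

img_ids = torch.zeros(packed_height, packed_width, 3)
img_ids[..., 0] = idx_offset                              # batch offset channel
img_ids[..., 1] = torch.arange(packed_height)[:, None]   # y positions
img_ids[..., 2] = torch.arange(packed_width)[None, :]    # x positions
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)

print(img_ids.shape)  # torch.Size([1, 4, 3])
print(img_ids[0])
# tensor([[1., 0., 0.],
#         [1., 0., 1.],
#         [1., 1., 0.],
#         [1., 1., 1.]])  -> one [batch_offset, y, x] triple per token
```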
@ -80,13 +98,17 @@ class KontextExtension:
         kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

+        # Extract tensor dimensions with descriptive names
+        # Latent tensor shape: [batch_size, channels, latent_height, latent_width]
+        batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
+
         # Pack the latents and generate IDs. The idx_offset distinguishes these
         # tokens from the main image's tokens, which have an index of 0.
         kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
         kontext_ids = generate_img_ids_with_offset(
-            h=kontext_latents_unpacked.shape[2],
-            w=kontext_latents_unpacked.shape[3],
-            batch_size=kontext_latents_unpacked.shape[0],
+            latent_height=latent_height,
+            latent_width=latent_width,
+            batch_size=batch_size,
             device=self._device,
             dtype=self._dtype,
             idx_offset=1,  # Distinguishes reference tokens from main image tokens
         )
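Downstream of this hunk, the packed reference tokens and their offset-1 ids are concatenated with the main image's tokens and offset-0 ids along the sequence dimension, so channel 0 of the ids is what lets the model tell the two images apart. The sketch below assumes `pack` is the standard FLUX 2x2 patchify (an einops `rearrange`); the latent shapes are illustrative:

```python
import torch
from einops import rearrange

def pack(latents: torch.Tensor) -> torch.Tensor:
    # Standard FLUX 2x2 patchify: [B, C, H, W] -> [B, (H//2 * W//2), C*4]
    return rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

# Illustrative latents: 16 channels on a 64x64 latent grid
main_latents = torch.randn(1, 16, 64, 64)
ref_latents = torch.randn(1, 16, 64, 64)

img = pack(main_latents)      # paired with ids generated at idx_offset=0
kontext = pack(ref_latents)   # paired with ids generated at idx_offset=1

# Joint sequence seen by the transformer: main tokens first, reference
# tokens after, distinguished only by their position ids.
combined = torch.cat([img, kontext], dim=1)
print(combined.shape)  # torch.Size([1, 2048, 64])
```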