Explain the Magic

commit ca1df60e54, parent 7549c1250d, committed by psychedelicious
mirror of https://github.com/invoke-ai/InvokeAI
@@ -384,7 +384,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
         )
 
-        # Instantiate our new extension if the conditioning is provided
         kontext_extension = None
         if self.kontext_conditioning is not None:
             # We need a VAE to encode the reference image. We can reuse the
@@ -400,7 +399,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 dtype=inference_dtype,
             )
 
-        # THE CRITICAL INTEGRATION POINT
         final_img, final_img_ids = x, img_ids
         original_seq_len = x.shape[1]  # Store the original sequence length
         if kontext_extension is not None:
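For a sense of the numbers involved (a worked example with assumed sizes, not taken from the commit): a 1024x1024 image maps to a 128x128 latent under the 8x VAE downsample, and to a 64x64 token grid after 2x2 packing, so the main image contributes 4096 tokens. original_seq_len records that count so the main image's tokens can be sliced back out after reference tokens are appended:

# Worked example with assumed sizes (not from the commit)
image_size = 1024
latent_size = image_size // 8        # 128: the FLUX VAE downsamples by 8x
packed_size = latent_size // 2       # 64: 2x2 patch packing halves each side
original_seq_len = packed_size ** 2  # 4096 main-image tokens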
@@ -426,7 +424,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             img_cond=img_cond,
         )
 
         # Extract only the main image tokens if kontext was applied
         if kontext_extension is not None:
             x = x[:, :original_seq_len, :]  # Keep only the first original_seq_len tokens
-
 
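The two hunks above are two halves of one pattern: append the reference tokens before the transformer runs, then slice them off afterwards so only the main image is decoded. A minimal sketch with illustrative shapes (variable names follow the diff; the concatenation itself happens in code not shown in these hunks):

import torch

x = torch.randn(1, 4096, 64)           # main image tokens [batch, seq, channels]
ref_tokens = torch.randn(1, 4096, 64)  # packed reference-image tokens (assumed shape)
original_seq_len = x.shape[1]          # remember where the main image ends

# Joint sequence seen by the transformer during denoising
final_img = torch.cat([x, ref_tokens], dim=1)  # [1, 8192, 64]
# ... denoising runs over final_img ...
x = final_img[:, :original_seq_len, :]  # keep only the main image tokens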
@@ -11,30 +11,48 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
 
 
 def generate_img_ids_with_offset(
-    h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
+    latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
 ) -> torch.Tensor:
     """Generate tensor of image position ids with an optional offset.
 
     Args:
-        h (int): Height of image in latent space.
-        w (int): Width of image in latent space.
-        batch_size (int): Batch size.
-        device (torch.device): Device.
-        dtype (torch.dtype): dtype.
+        latent_height (int): Height of image in latent space (after packing, this becomes h//2).
+        latent_width (int): Width of image in latent space (after packing, this becomes w//2).
+        batch_size (int): Number of images in the batch.
+        device (torch.device): Device to create tensors on.
+        dtype (torch.dtype): Data type for the tensors.
         idx_offset (int): Offset to add to the first dimension of the image ids.
 
     Returns:
-        torch.Tensor: Image position ids.
+        torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
     """
 
     if device.type == "mps":
         orig_dtype = dtype
         dtype = torch.float16
 
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
-    img_ids[..., 0] = idx_offset  # Set the offset for the first dimension
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
+    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
+    packed_height = latent_height // 2
+    packed_width = latent_width // 2
+
+    # Create base tensor for position IDs with shape [packed_height, packed_width, 3]
+    # The 3 channels represent: [batch_offset, y_position, x_position]
+    img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
+
+    # Set the batch offset for all positions
+    img_ids[..., 0] = idx_offset
+
+    # Create y-coordinate indices (vertical positions)
+    y_indices = torch.arange(packed_height, device=device, dtype=dtype)
+    # Broadcast y_indices to match the spatial dimensions [packed_height, 1]
+    img_ids[..., 1] = y_indices[:, None]
+
+    # Create x-coordinate indices (horizontal positions)
+    x_indices = torch.arange(packed_width, device=device, dtype=dtype)
+    # Broadcast x_indices to match the spatial dimensions [1, packed_width]
+    img_ids[..., 2] = x_indices[None, :]
+
+    # Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
 
     if device.type == "mps":
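Two claims in the new comments are easy to verify in isolation: FLUX packs latents into 2x2 patches, which is why the id grid uses latent_height // 2 and latent_width // 2, and the three id channels carry [batch_offset, y, x]. A sketch (the rearrange pattern mirrors FLUX-style packing; InvokeAI's own pack helper may differ in details):

import torch
from einops import rearrange, repeat

# 2x2 patch packing: [b, c, h, w] -> [b, (h//2 * w//2), c*4]
latents = torch.randn(1, 16, 4, 6)
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)  # torch.Size([1, 6, 64]): (4//2)*(6//2) = 6 tokens of 16*2*2 = 64 channels

# Position ids for the same grid, following generate_img_ids_with_offset
ids = torch.zeros(2, 3, 3)              # [packed_h, packed_w, (offset, y, x)]
ids[..., 0] = 1                         # idx_offset=1, as used for reference tokens
ids[..., 1] = torch.arange(2)[:, None]  # y coordinate down the rows
ids[..., 2] = torch.arange(3)[None, :]  # x coordinate across the columns
ids = repeat(ids, "h w c -> b (h w) c", b=1)
print(ids[0])  # rows of [1, y, x] for each (y, x) in row-major order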
@@ -80,13 +98,17 @@ class KontextExtension:
 
         kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
 
+        # Extract tensor dimensions with descriptive names
+        # Latent tensor shape: [batch_size, channels, latent_height, latent_width]
+        batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
+
         # Pack the latents and generate IDs. The idx_offset distinguishes these
         # tokens from the main image's tokens, which have an index of 0.
         kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
         kontext_ids = generate_img_ids_with_offset(
-            h=kontext_latents_unpacked.shape[2],
-            w=kontext_latents_unpacked.shape[3],
-            batch_size=kontext_latents_unpacked.shape[0],
+            latent_height=latent_height,
+            latent_width=latent_width,
+            batch_size=batch_size,
             device=self._device,
             dtype=self._dtype,
             idx_offset=1,  # Distinguishes reference tokens from main image tokens
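Downstream of this hunk (a sketch; the joining code is not shown here), the reference ids produced with idx_offset=1 sit next to the main image's ids produced with idx_offset=0, so positional encoding can distinguish the two streams even though their (y, x) coordinates coincide:

import torch

def tiny_ids(offset: int) -> torch.Tensor:
    # 2x2 packed grid of [offset, y, x] ids, as generate_img_ids_with_offset builds
    ids = torch.zeros(2, 2, 3)
    ids[..., 0] = offset
    ids[..., 1] = torch.arange(2)[:, None]
    ids[..., 2] = torch.arange(2)[None, :]
    return ids.reshape(1, -1, 3)

joint_ids = torch.cat([tiny_ids(0), tiny_ids(1)], dim=1)  # [1, 8, 3]
# The same (y, x) pairs appear twice; the leading channel tells them apart.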