mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
ruff
This commit is contained in:
committed by
psychedelicious
parent
ca1df60e54
commit
51e1c56636
@ -11,7 +11,12 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
|
||||
|
||||
|
||||
def generate_img_ids_with_offset(
|
||||
latent_height: int, latent_width: int, batch_size: int, device: torch.device, dtype: torch.dtype, idx_offset: int = 0
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
idx_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids with an optional offset.
|
||||
|
||||
@ -34,24 +39,24 @@ def generate_img_ids_with_offset(
|
||||
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
|
||||
packed_height = latent_height // 2
|
||||
packed_width = latent_width // 2
|
||||
|
||||
|
||||
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
|
||||
# The 3 channels represent: [batch_offset, y_position, x_position]
|
||||
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
|
||||
|
||||
|
||||
# Set the batch offset for all positions
|
||||
img_ids[..., 0] = idx_offset
|
||||
|
||||
|
||||
# Create y-coordinate indices (vertical positions)
|
||||
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
|
||||
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
|
||||
img_ids[..., 1] = y_indices[:, None]
|
||||
|
||||
# Create x-coordinate indices (horizontal positions)
|
||||
|
||||
# Create x-coordinate indices (horizontal positions)
|
||||
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
|
||||
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
|
||||
img_ids[..., 2] = x_indices[None, :]
|
||||
|
||||
|
||||
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
|
||||
|
Reference in New Issue
Block a user