This commit is contained in:
Kent Keirsey
2025-07-03 14:12:20 -04:00
committed by psychedelicious
parent 71e6f00e10
commit 52dbdb7118
2 changed files with 12 additions and 18 deletions

View File

@ -895,9 +895,7 @@ class FluxDenoiseInvocation(BaseInvocation):
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(
self, context: InvocationContext
) -> Callable[[PipelineIntermediateState], None]:
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
# The denoise function now handles Kontext conditioning correctly,
# so we don't need to slice the latents here

View File

@ -95,44 +95,40 @@ class KontextExtension:
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS),
key=lambda x: x[0]
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(
vae_info=vae_info,
image_tensor=image_tensor
)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents and generate IDs
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
@ -143,7 +139,7 @@ class KontextExtension:
dtype=self._dtype,
idx_offset=1,
)
return kontext_latents_packed, kontext_ids
def ensure_batch_size(self, target_batch_size: int) -> None: