From bd2540994e7cbd858847f54d44d5d8d05bcd6989 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Fri, 30 Aug 2024 12:31:29 -0400 Subject: [PATCH] Initial attempt at preview images --- .../app/invocations/flux_text_to_image.py | 69 ++++++++++++------- invokeai/backend/flux/sampling.py | 14 +++- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 248122d8cd..7a791c6eb8 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -20,7 +20,11 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.util import image_to_dataURL +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.backend.model_manager.config import CheckpointConfigBase @invocation( "flux_text_to_image", @@ -92,6 +96,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): x, img_ids = prepare_latent_img_patches(x) + if not isinstance(transformer_info.config, CheckpointConfigBase): + raise ValueError("Transformer provided must be a valid Checkpoint model") + is_schnell = "schnell" in transformer_info.config.config_path timesteps = get_schedule( @@ -106,34 +113,46 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): with transformer_info as transformer: assert isinstance(transformer, Flux) - def step_callback() -> None: + def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None: if context.util.is_canceled(): raise CanceledException + latent_rgb_factors =torch.tensor([ + [-0.0371, 0.0132, 0.0587], + [ 0.0029, 0.0264, 0.0824], + [ 0.0297, -0.0723, -0.0479], + [-0.0213, 0.0071, 0.0513], + [ 0.0928, 0.0859, 0.0504], + [ 0.0023, 0.0367, 0.0112], + [ 0.0546, 0.1134, 0.1203], + [-0.0174, -0.0191, -0.0284], + [-0.0301, 0.0049, 0.0954], + [ 0.0904, 0.0623, -0.0514], + [-0.0483, 0.0213, -0.0012], + [ 0.0477, -0.0007, -0.0083], + [ 0.0936, 0.0898, 0.0930], + [-0.1173, -0.0266, -0.0854], + [-0.0004, -0.0507, -0.0019], + [-0.1204, -0.0880, -0.0643] + ], dtype=img.dtype, device=img.device) + latent_image = unpack(img.float(), self.height, self.width).squeeze() + latent_image_perm = latent_image.permute(1, 2, 0).to(dtype=img.dtype, device=img.device) + latent_image = latent_image_perm @ latent_rgb_factors + latents_ubyte = ( + ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF) # change scale from -1..1 to 0..1 # to 0..255 + ).to(device="cpu", dtype=torch.uint8) + image = Image.fromarray(latents_ubyte.cpu().numpy()) + (width, height) = image.size + width *= 8 + height *= 8 + dataURL = image_to_dataURL(image, image_format="JPEG") - # TODO: Make this look like the image before re-enabling - # latent_image = unpack(img.float(), self.height, self.width) - # latent_image = latent_image.squeeze() # Remove unnecessary dimensions - # flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128] - - # # Create a new tensor of the required shape [255, 255, 3] - # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format - - # # Convert to a NumPy array and then to a PIL Image - # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8)) - - # (width, height) = image.size - # width *= 8 - # height *= 8 - - # dataURL = image_to_dataURL(image, image_format="JPEG") - - # # TODO: move this whole function to invocation context to properly reference these variables - # context._services.events.emit_invocation_denoise_progress( - # context._data.queue_item, - # context._data.invocation, - # state, - # ProgressImage(dataURL=dataURL, width=width, height=height), - # ) + context._services.events.emit_invocation_denoise_progress( + context._data.queue_item, + context._data.invocation, + state, + ProgressImage(dataURL=dataURL, width=width, height=height), + ) + x = denoise( model=transformer, diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py index 7a35b0aedf..dbc448d843 100644 --- a/invokeai/backend/flux/sampling.py +++ b/invokeai/backend/flux/sampling.py @@ -10,7 +10,7 @@ from tqdm import tqdm from invokeai.backend.flux.model import Flux from invokeai.backend.flux.modules.conditioner import HFEncoder - +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState def get_noise( num_samples: int, @@ -108,9 +108,10 @@ def denoise( vec: Tensor, # sampling parameters timesteps: list[float], - step_callback: Callable[[], None], + step_callback: Callable[[Tensor, PipelineIntermediateState], None], guidance: float = 4.0, ): + step = 0 # guidance_vec is ignored for schnell. guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))): @@ -126,7 +127,14 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - step_callback() + step_callback(img, PipelineIntermediateState( + step=step, + order=1, + total_steps=len(timesteps), + timestep=int(t_curr), + latents=img, + )) + step+=1 return img