Set up scaffolding for in-progress images and add the ability to cancel the FLUX node

Brandon Rising 2024-08-24 11:01:16 -04:00 committed by Brandon
parent a808ce81fd
commit 56b9906e2e
5 changed files with 138 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 from einops import rearrange
 from PIL import Image
@@ -13,12 +14,15 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.step_callback import PipelineIntermediateState
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import image_to_dataURL


 @invocation(
@@ -108,6 +112,35 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)

+            def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
+                if context.util.is_canceled():
+                    raise CanceledException
+
+                # TODO: Make this look like the image
+                latent_image = unpack(img.float(), self.height, self.width)
+                latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
+                flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]
+
+                # Create a new tensor of the required shape [255, 255, 3]
+                latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format
+
+                # Convert to a NumPy array and then to a PIL Image
+                image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
+
+                (width, height) = image.size
+                width *= 8
+                height *= 8
+
+                dataURL = image_to_dataURL(image, image_format="JPEG")
+
+                # TODO: move this whole function to invocation context to properly reference these variables
+                context._services.events.emit_invocation_denoise_progress(
+                    context._data.queue_item,
+                    context._data.invocation,
+                    state,
+                    ProgressImage(dataURL=dataURL, width=width, height=height),
+                )
+
             x = denoise(
                 model=transformer,
                 img=img,
@@ -116,6 +149,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 txt_ids=txt_ids,
                 vec=clip_embeddings,
                 timesteps=timesteps,
+                step_callback=step_callback,
                 guidance=self.guidance,
             )
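
The preview hack in step_callback can be reproduced outside InvokeAI. The sketch below is illustrative only: the latent shape is a stand-in for whatever unpack returns at the requested resolution, and the inline base64 encoding stands in for image_to_dataURL. The noisy 255x255 thumbnail it produces is exactly what the "Make this look like the image" TODO promises to replace later.

import base64
import io

import numpy as np
import torch
from PIL import Image

# Stand-in for the unpacked latent; the real shape depends on height/width.
latent = torch.randn(16, 128, 128)

flattened = latent.reshape(-1)  # flatten so a fixed pixel budget can be sliced off
preview = flattened[: 255 * 255 * 3].reshape(255, 255, 3)  # raw latent values as RGB
image = Image.fromarray(preview.cpu().numpy().astype(np.uint8))

# Stand-in for image_to_dataURL(image, image_format="JPEG")
buf = io.BytesIO()
image.save(buf, format="JPEG")
data_url = "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("ascii")
print(image.size, len(data_url))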

View File

@@ -8,6 +8,7 @@ from einops import rearrange, repeat
 from torch import Tensor
 from tqdm import tqdm

+from invokeai.app.util.step_callback import PipelineIntermediateState
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.conditioner import HFEncoder
@@ -108,6 +109,7 @@ def denoise(
     vec: Tensor,
     # sampling parameters
     timesteps: list[float],
+    step_callback: Callable[[Tensor, PipelineIntermediateState], None],
     guidance: float = 4.0,
 ):
     dtype = model.txt_in.bias.dtype
@@ -121,6 +123,7 @@ def denoise(
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

+    step_count = 0
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         pred = model(
@@ -134,6 +137,17 @@ def denoise(
         )

         img = img + (t_prev - t_curr) * pred
+        step_callback(
+            img,
+            PipelineIntermediateState(
+                step=step_count,
+                order=0,
+                total_steps=len(timesteps),
+                timestep=math.floor(t_curr),
+                latents=img,
+            ),
+        )
+        step_count += 1

     return img
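
The control flow that makes cancellation work is simply an exception thrown from inside the sampling loop. A minimal, self-contained sketch of the pattern (run_denoise and the dict-based flag are illustrative stand-ins, not InvokeAI APIs):

from typing import Callable

import torch


class CanceledException(Exception):
    """Raised by the step callback to abort the loop."""


def run_denoise(steps: int, step_callback: Callable[[torch.Tensor, int], None]) -> torch.Tensor:
    img = torch.randn(1, 16)
    for step in range(steps):
        img = img * 0.9  # stand-in for one denoising update
        step_callback(img, step)  # may raise CanceledException to unwind
    return img


canceled = {"flag": False}  # stand-in for context.util.is_canceled()


def step_callback(img: torch.Tensor, step: int) -> None:
    if canceled["flag"]:
        raise CanceledException
    print(f"step {step}: latent mean {img.mean().item():.4f}")


result = run_denoise(steps=4, step_callback=step_callback)

Because the callback fires after each update of img, cancellation takes effect at the next step boundary rather than mid-forward-pass, and the exception propagates out of denoise to be handled upstream.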

View File

@@ -0,0 +1,86 @@
+# Initially pulled from https://github.com/black-forest-labs/flux
+
+import os
+from dataclasses import dataclass
+
+from invokeai.backend.flux.model import FluxParams
+from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
+
+
+@dataclass
+class ModelSpec:
+    params: FluxParams
+    ae_params: AutoEncoderParams
+    ckpt_path: str | None
+    ae_path: str | None
+    repo_id: str | None
+    repo_flow: str | None
+    repo_ae: str | None
+
+
+configs = {
+    "flux-dev": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-schnell": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-schnell",
+        repo_flow="flux1-schnell.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_SCHNELL"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=False,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+}
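
A quick usage sketch for the vendored config table. Hedged assumptions: Flux takes its FluxParams as the sole constructor argument (as in the upstream black-forest-labs code this file was pulled from), and ckpt_path is populated only when the corresponding environment variable is set.

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs

spec = configs["flux-schnell"]
print(spec.repo_id, spec.repo_flow)  # where the weights would be fetched from
model = Flux(spec.params)            # build the transformer from its params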

View File

@@ -1,8 +1,8 @@
 from pathlib import Path

 import accelerate
-from flux.model import Flux
-from flux.util import configs as flux_configs
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file

 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8

View File

@@ -4,10 +4,10 @@ from pathlib import Path
 import accelerate
 import torch
-from flux.model import Flux
-from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file

+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs as flux_configs
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
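
With the vendored copies in place, both quantization scripts resolve entirely inside the invokeai package with no dependency on the external flux distribution. A rough smoke test of the new import paths; the meta-device instantiation here is an assumption about how such loading scripts typically defer weight allocation, not a line from this commit:

import accelerate

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs as flux_configs

params = flux_configs["flux-dev"].params
with accelerate.init_empty_weights():
    # Build the module on the meta device; no weight memory is allocated yet.
    model = Flux(params)
print(sum(p.numel() for p in model.parameters()))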