mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

commit 56b9906e2e (parent: a808ce81fd)
Set up scaffolding for in-progress images and add the ability to cancel the FLUX node
invokeai/app/invocations/flux_text_to_image.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 from einops import rearrange
 from PIL import Image
@@ -13,12 +14,15 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.step_callback import PipelineIntermediateState
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import image_to_dataURL


 @invocation(
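For orientation, a minimal sketch of what a helper like image_to_dataURL plausibly does — the real InvokeAI implementation is not shown in this diff, so the body below (and the sketch's name) is an assumption: serialize a PIL image and wrap it in a base64 data URL for the frontend.

import base64
import io

from PIL import Image


def image_to_dataURL_sketch(image: Image.Image, image_format: str = "PNG") -> str:
    # Serialize the PIL image into an in-memory buffer in the requested format.
    buffer = io.BytesIO()
    image.save(buffer, format=image_format)
    # Base64-encode the bytes and wrap them in a data URL the browser can render.
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/{image_format.lower()};base64,{encoded}"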
@@ -108,6 +112,35 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)

+            def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
+                if context.util.is_canceled():
+                    raise CanceledException
+
+                # TODO: Make this look like the image
+                latent_image = unpack(img.float(), self.height, self.width)
+                latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
+                flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]
+
+                # Create a new tensor of the required shape [255, 255, 3]
+                latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format
+
+                # Convert to a NumPy array and then to a PIL Image
+                image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
+
+                (width, height) = image.size
+                width *= 8
+                height *= 8
+
+                dataURL = image_to_dataURL(image, image_format="JPEG")
+
+                # TODO: move this whole function to invocation context to properly reference these variables
+                context._services.events.emit_invocation_denoise_progress(
+                    context._data.queue_item,
+                    context._data.invocation,
+                    state,
+                    ProgressImage(dataURL=dataURL, width=width, height=height),
+                )
+
             x = denoise(
                 model=transformer,
                 img=img,
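As the TODO notes, this preview is scaffolding rather than a real decode: the packed latent is flattened and its first 255*255*3 values are reinterpreted as an RGB grid, not run through the VAE. A self-contained sketch of the same conversion; the [48, 128, 128] shape mirrors the comment in the diff and is an assumption about one particular resolution:

import numpy as np
import torch
from PIL import Image

# Hypothetical packed-latent stand-in (random noise, for illustration only).
latent = torch.randn(48, 128, 128)

flattened = latent.reshape(-1)
# Reinterpret the first 255*255*3 values as RGB -- a placeholder, since the
# values are latents, not pixels.
preview = flattened[: 255 * 255 * 3].reshape(255, 255, 3)
preview = (preview * 64 + 128).clamp(0, 255)  # shift into a displayable range
pil_image = Image.fromarray(preview.cpu().numpy().astype(np.uint8))
assert pil_image.size == (255, 255)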
@@ -116,6 +149,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 txt_ids=txt_ids,
                 vec=clip_embeddings,
                 timesteps=timesteps,
+                step_callback=step_callback,
                 guidance=self.guidance,
             )

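Taken together, these changes give the FLUX node cooperative cancellation: denoise() calls back once per step, and the callback raises CanceledException when the session processor has flagged the queue item. A condensed sketch of that control flow, using stand-in types rather than the real service objects:

class CanceledException(Exception):
    # Stand-in for the class imported from session_processor_common above.
    pass


def denoise_stub(step_callback) -> None:
    for step in range(4):  # pretend denoising loop
        step_callback(step)


def run(is_canceled) -> bool:
    def step_callback(step: int) -> None:
        if is_canceled():
            raise CanceledException  # unwinds out of the denoising loop

    try:
        denoise_stub(step_callback)
        return True
    except CanceledException:
        return False


print(run(lambda: False))  # True: ran to completion
print(run(lambda: True))   # False: canceled on the first step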
invokeai/backend/flux/sampling.py
@@ -8,6 +8,7 @@ from einops import rearrange, repeat
 from torch import Tensor
 from tqdm import tqdm

+from invokeai.app.util.step_callback import PipelineIntermediateState
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.conditioner import HFEncoder

@@ -108,6 +109,7 @@ def denoise(
     vec: Tensor,
     # sampling parameters
     timesteps: list[float],
+    step_callback: Callable[[Tensor, PipelineIntermediateState], None],
     guidance: float = 4.0,
 ):
     dtype = model.txt_in.bias.dtype
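The new parameter keeps denoise() UI-agnostic: progress reporting and cancellation are entirely the caller's concern. Note the annotation assumes Callable (e.g. from typing) is already in scope in sampling.py; that import is not visible in these hunks. A minimal caller-side stub satisfying the signature, for scripts that call denoise() directly and want neither previews nor cancellation:

from torch import Tensor


def noop_step_callback(img: Tensor, state: object) -> None:
    # Deliberately does nothing; matches Callable[[Tensor, PipelineIntermediateState], None].
    return None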
@@ -121,6 +123,7 @@ def denoise(

     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+    step_count = 0
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         pred = model(
@@ -134,6 +137,17 @@ def denoise(
         )

         img = img + (t_prev - t_curr) * pred
+        step_callback(
+            img,
+            PipelineIntermediateState(
+                step=step_count,
+                order=0,
+                total_steps=len(timesteps),
+                timestep=math.floor(t_curr),
+                latents=img,
+            ),
+        )
+        step_count += 1

     return img

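Two things worth noting about the values delivered: math.floor requires an import of math in sampling.py that is not visible in these hunks, and the loop runs len(timesteps) - 1 times while total_steps is reported as len(timesteps), so state.step never reaches total_steps - 1. A consumer-side sketch (hypothetical names) showing what a callback observes each iteration:

from torch import Tensor


def logging_step_callback(img: Tensor, state) -> None:
    # state.step starts at 0 and increments once per denoising iteration;
    # state.latents / img hold the packed latent tensor for preview use.
    pct = 100 * (state.step + 1) / state.total_steps
    print(f"[{pct:5.1f}%] step={state.step} timestep={state.timestep} shape={tuple(img.shape)}")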
invokeai/backend/flux/util.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+# Initially pulled from https://github.com/black-forest-labs/flux
+
+import os
+from dataclasses import dataclass
+
+from invokeai.backend.flux.model import FluxParams
+from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
+
+
+@dataclass
+class ModelSpec:
+    params: FluxParams
+    ae_params: AutoEncoderParams
+    ckpt_path: str | None
+    ae_path: str | None
+    repo_id: str | None
+    repo_flow: str | None
+    repo_ae: str | None
+
+
+configs = {
+    "flux-dev": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_flow="flux1-dev.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_DEV"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=True,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+    "flux-schnell": ModelSpec(
+        repo_id="black-forest-labs/FLUX.1-schnell",
+        repo_flow="flux1-schnell.safetensors",
+        repo_ae="ae.safetensors",
+        ckpt_path=os.getenv("FLUX_SCHNELL"),
+        params=FluxParams(
+            in_channels=64,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10_000,
+            qkv_bias=True,
+            guidance_embed=False,
+        ),
+        ae_path=os.getenv("AE"),
+        ae_params=AutoEncoderParams(
+            resolution=256,
+            in_channels=3,
+            ch=128,
+            out_ch=3,
+            ch_mult=[1, 2, 4, 4],
+            num_res_blocks=2,
+            z_channels=16,
+            scale_factor=0.3611,
+            shift_factor=0.1159,
+        ),
+    ),
+}
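In this table the dev and schnell entries differ only in repo_id/repo_flow, the checkpoint environment variable, and guidance_embed (True for dev, False for schnell, matching the "ignored for schnell" comment in sampling.py). A usage sketch, assuming only the module layout added in this commit:

from invokeai.backend.flux.util import configs

spec = configs["flux-schnell"]
print(spec.repo_id)                # "black-forest-labs/FLUX.1-schnell"
print(spec.params.guidance_embed)  # False -- schnell is distilled without guidance embedding
print(spec.ae_params.z_channels)   # 16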
(FLUX model loading script for bitsandbytes LLM.int8 quantization; path not shown in this view)
@@ -1,8 +1,8 @@
 from pathlib import Path

 import accelerate
-from flux.model import Flux
-from flux.util import configs as flux_configs
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file

 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
(FLUX model loading script for bitsandbytes NF4 quantization; path not shown in this view)
@@ -4,10 +4,10 @@ from pathlib import Path

 import accelerate
 import torch
-from flux.model import Flux
-from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file

+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs as flux_configs
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
