Initial attempt at preview images

Brandon Rising 2024-08-30 12:31:29 -04:00
parent 87261bdbc9
commit bd2540994e
2 changed files with 55 additions and 28 deletions

View File

@@ -20,7 +20,11 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import image_to_dataURL
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
+from invokeai.backend.model_manager.config import CheckpointConfigBase

 @invocation(
     "flux_text_to_image",
@@ -92,6 +96,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         x, img_ids = prepare_latent_img_patches(x)

+        if not isinstance(transformer_info.config, CheckpointConfigBase):
+            raise ValueError("Transformer provided must be a valid Checkpoint model")
+
         is_schnell = "schnell" in transformer_info.config.config_path

         timesteps = get_schedule(
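The isinstance guard added in this hunk does double duty: it fails fast on non-checkpoint models and narrows the type of transformer_info.config so the config_path access below it type-checks. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the real InvokeAI config hierarchy:

# Hypothetical stand-ins illustrating the isinstance guard above.
class ModelConfigBase:
    pass

class CheckpointConfigBase(ModelConfigBase):
    config_path: str = "flux-schnell.yaml"

def is_schnell_model(config: ModelConfigBase) -> bool:
    if not isinstance(config, CheckpointConfigBase):
        raise ValueError("Transformer provided must be a valid Checkpoint model")
    # After the guard, type checkers know `config` is a CheckpointConfigBase,
    # so accessing `config_path` is safe.
    return "schnell" in config.config_path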
@@ -106,34 +113,46 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)

-            def step_callback() -> None:
+            def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
                 if context.util.is_canceled():
                     raise CanceledException

-            # TODO: Make this look like the image before re-enabling
-            # latent_image = unpack(img.float(), self.height, self.width)
-            # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
-            # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]
-
-            # # Create a new tensor of the required shape [255, 255, 3]
-            # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format
-
-            # # Convert to a NumPy array and then to a PIL Image
-            # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
-            # (width, height) = image.size
-            # width *= 8
-            # height *= 8
-            # dataURL = image_to_dataURL(image, image_format="JPEG")
-
-            # # TODO: move this whole function to invocation context to properly reference these variables
-            # context._services.events.emit_invocation_denoise_progress(
-            #     context._data.queue_item,
-            #     context._data.invocation,
-            #     state,
-            #     ProgressImage(dataURL=dataURL, width=width, height=height),
-            # )
+                # Fixed projection from the 16 FLUX latent channels to RGB for a rough preview.
+                latent_rgb_factors = torch.tensor(
+                    [
+                        [-0.0371, 0.0132, 0.0587],
+                        [0.0029, 0.0264, 0.0824],
+                        [0.0297, -0.0723, -0.0479],
+                        [-0.0213, 0.0071, 0.0513],
+                        [0.0928, 0.0859, 0.0504],
+                        [0.0023, 0.0367, 0.0112],
+                        [0.0546, 0.1134, 0.1203],
+                        [-0.0174, -0.0191, -0.0284],
+                        [-0.0301, 0.0049, 0.0954],
+                        [0.0904, 0.0623, -0.0514],
+                        [-0.0483, 0.0213, -0.0012],
+                        [0.0477, -0.0007, -0.0083],
+                        [0.0936, 0.0898, 0.0930],
+                        [-0.1173, -0.0266, -0.0854],
+                        [-0.0004, -0.0507, -0.0019],
+                        [-0.1204, -0.0880, -0.0643],
+                    ],
+                    dtype=img.dtype,
+                    device=img.device,
+                )
+
+                # Unpack the patched latents to (16, H, W), then project each pixel to RGB.
+                latent_image = unpack(img.float(), self.height, self.width).squeeze()
+                latent_image_perm = latent_image.permute(1, 2, 0).to(dtype=img.dtype, device=img.device)
+                latent_image = latent_image_perm @ latent_rgb_factors
+
+                # Rescale from [-1, 1] to [0, 255] and quantize to uint8 for a PIL image.
+                latents_ubyte = ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).to(device="cpu", dtype=torch.uint8)
+                image = Image.fromarray(latents_ubyte.numpy())
+
+                (width, height) = image.size
+                width *= 8
+                height *= 8
+
+                dataURL = image_to_dataURL(image, image_format="JPEG")
+                context._services.events.emit_invocation_denoise_progress(
+                    context._data.queue_item,
+                    context._data.invocation,
+                    state,
+                    ProgressImage(dataURL=dataURL, width=width, height=height),
+                )

             x = denoise(
                 model=transformer,
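For reference, the preview generation added above boils down to one matrix multiply per pixel: each of the 16 FLUX latent channels contributes a fixed weight to R, G, and B. Below is a minimal standalone sketch of that projection. The factor values are the ones from the commit; the helper name and the (16, H, W) input shape are assumptions for illustration:

import torch
from PIL import Image

# Per-channel RGB weights copied from the commit above (16 latent channels x 3).
LATENT_RGB_FACTORS = torch.tensor([
    [-0.0371, 0.0132, 0.0587], [0.0029, 0.0264, 0.0824],
    [0.0297, -0.0723, -0.0479], [-0.0213, 0.0071, 0.0513],
    [0.0928, 0.0859, 0.0504], [0.0023, 0.0367, 0.0112],
    [0.0546, 0.1134, 0.1203], [-0.0174, -0.0191, -0.0284],
    [-0.0301, 0.0049, 0.0954], [0.0904, 0.0623, -0.0514],
    [-0.0483, 0.0213, -0.0012], [0.0477, -0.0007, -0.0083],
    [0.0936, 0.0898, 0.0930], [-0.1173, -0.0266, -0.0854],
    [-0.0004, -0.0507, -0.0019], [-0.1204, -0.0880, -0.0643],
])

def latents_to_preview(latents: torch.Tensor) -> Image.Image:
    # (16, H, W) -> (H, W, 16), then project to (H, W, 3) with a single matmul.
    rgb = latents.permute(1, 2, 0) @ LATENT_RGB_FACTORS
    # Rescale from roughly [-1, 1] to [0, 255] and quantize to uint8.
    ubyte = ((rgb + 1) / 2).clamp(0, 1).mul(255).to(torch.uint8)
    return Image.fromarray(ubyte.cpu().numpy())

# Usage: a random tensor stands in for a real intermediate latent.
preview = latents_to_preview(torch.randn(16, 64, 64))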

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.conditioner import HFEncoder
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

 def get_noise(
     num_samples: int,
@ -108,9 +108,10 @@ def denoise(
vec: Tensor, vec: Tensor,
# sampling parameters # sampling parameters
timesteps: list[float], timesteps: list[float],
step_callback: Callable[[], None], step_callback: Callable[[Tensor, PipelineIntermediateState], None],
guidance: float = 4.0, guidance: float = 4.0,
): ):
step = 0
# guidance_vec is ignored for schnell. # guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))): for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
@@ -126,7 +127,14 @@ def denoise(
         )
         img = img + (t_prev - t_curr) * pred

-        step_callback()
+        step_callback(img, PipelineIntermediateState(
+            step=step,
+            order=1,
+            total_steps=len(timesteps),
+            timestep=int(t_curr),
+            latents=img,
+        ))
+        step += 1

     return img
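With this change, the denoise loop hands each intermediate sample to the caller after every Euler step. Below is a minimal sketch of the new contract from the caller's side; the stand-in state class and toy loop are simplified assumptions, not the real InvokeAI types:

from dataclasses import dataclass
from typing import Callable, List
import torch

@dataclass
class IntermediateState:
    # Simplified stand-in for PipelineIntermediateState.
    step: int
    order: int
    total_steps: int
    timestep: int
    latents: torch.Tensor

def toy_denoise(img: torch.Tensor, timesteps: List[float],
                step_callback: Callable[[torch.Tensor, IntermediateState], None]) -> torch.Tensor:
    step = 0
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        pred = torch.zeros_like(img)  # stand-in for the transformer prediction
        img = img + (t_prev - t_curr) * pred
        # Report the updated latents and progress state after each step,
        # mirroring the loop above.
        step_callback(img, IntermediateState(step=step, order=1,
                                             total_steps=len(timesteps),
                                             timestep=int(t_curr), latents=img))
        step += 1
    return img

# Usage: print progress for a two-step schedule.
toy_denoise(torch.randn(1, 16, 8, 8), [1.0, 0.5, 0.0],
            lambda _img, s: print(f"step {s.step + 1} of {s.total_steps - 1}"))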