mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Initial attempt at preview images
This commit is contained in:
parent
87261bdbc9
commit
bd2540994e
@ -20,7 +20,11 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.util import image_to_dataURL
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
@ -92,6 +96,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
x, img_ids = prepare_latent_img_patches(x)
|
||||
|
||||
if not isinstance(transformer_info.config, CheckpointConfigBase):
|
||||
raise ValueError("Transformer provided must be a valid Checkpoint model")
|
||||
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
timesteps = get_schedule(
|
||||
@ -106,34 +113,46 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
    """Report denoising progress with a rough RGB preview of the current latents.

    Called after each denoising step. Aborts the run if the invocation has
    been canceled; otherwise projects the latents to RGB, encodes them as a
    JPEG data URL, and emits a denoise-progress event.

    Args:
        img: Latents for the current step; they are passed through ``unpack``
            together with ``self.height`` and ``self.width`` before projection.
        state: Intermediate pipeline state, forwarded unchanged to the
            progress event.

    Raises:
        CanceledException: If ``context.util.is_canceled()`` reports True.
    """
    if context.util.is_canceled():
        raise CanceledException

    # Linear latent -> RGB projection: one row per latent channel (16 rows),
    # one column per colour channel. Kept in float32 so the preview is not
    # computed in the model's (possibly reduced-precision) dtype.
    latent_rgb_factors = torch.tensor(
        [
            [-0.0371, 0.0132, 0.0587],
            [0.0029, 0.0264, 0.0824],
            [0.0297, -0.0723, -0.0479],
            [-0.0213, 0.0071, 0.0513],
            [0.0928, 0.0859, 0.0504],
            [0.0023, 0.0367, 0.0112],
            [0.0546, 0.1134, 0.1203],
            [-0.0174, -0.0191, -0.0284],
            [-0.0301, 0.0049, 0.0954],
            [0.0904, 0.0623, -0.0514],
            [-0.0483, 0.0213, -0.0012],
            [0.0477, -0.0007, -0.0083],
            [0.0936, 0.0898, 0.0930],
            [-0.1173, -0.0266, -0.0854],
            [-0.0004, -0.0507, -0.0019],
            [-0.1204, -0.0880, -0.0643],
        ],
        dtype=torch.float32,
        device=img.device,
    )

    # After squeeze() the tensor is permuted so that dim 0 becomes the last
    # axis, which must have length 16 to match the projection matrix.
    # NOTE(review): presumably unpack() returns (1, 16, h, w) -- confirm.
    latent_image = unpack(img.float(), self.height, self.width).squeeze()
    latent_rgb = latent_image.permute(1, 2, 0) @ latent_rgb_factors

    # Change scale from -1..1 to 0..1, then to 0..255, and move to the CPU
    # as bytes so PIL can consume the array.
    latents_ubyte = ((latent_rgb + 1) / 2).clamp(0, 1).mul(0xFF).to(device="cpu", dtype=torch.uint8)
    image = Image.fromarray(latents_ubyte.numpy())

    # The reported size is 8x the preview's own size.
    # NOTE(review): presumably the latent-to-pixel scale factor -- confirm.
    width, height = image.size
    width *= 8
    height *= 8

    dataURL = image_to_dataURL(image, image_format="JPEG")

    # TODO: move this whole function to invocation context to properly
    # reference these variables instead of reaching into private attributes.
    context._services.events.emit_invocation_denoise_progress(
        context._data.queue_item,
        context._data.invocation,
        state,
        ProgressImage(dataURL=dataURL, width=width, height=height),
    )
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
|
@ -10,7 +10,7 @@ from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
@ -108,9 +108,10 @@ def denoise(
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
step_callback: Callable[[Tensor, PipelineIntermediateState], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
step = 0
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
@ -126,7 +127,14 @@ def denoise(
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
step_callback()
|
||||
step_callback(img, PipelineIntermediateState(
|
||||
step=step,
|
||||
order=1,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t_curr),
|
||||
latents=img,
|
||||
))
|
||||
step+=1
|
||||
|
||||
return img
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user