Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Initial attempt at preview images

commit bd2540994e (parent 87261bdbc9)
invokeai/app/invocations/flux_text_to_image.py

@@ -20,7 +20,11 @@ from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import image_to_dataURL
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
 
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 @invocation(
     "flux_text_to_image",
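A note on the new imports: image_to_dataURL is what lets a preview frame travel inside an event payload, since the PIL image is serialized to a base64 data URL string instead of being written to disk. A minimal sketch of that general technique (illustrative, not necessarily the helper's exact implementation):

    import base64
    import io

    from PIL import Image


    def image_to_data_url(image: Image.Image, image_format: str = "JPEG") -> str:
        # Serialize the PIL image into an in-memory buffer, then base64-encode it
        # so it can be embedded directly in a JSON event payload.
        buffer = io.BytesIO()
        image.save(buffer, format=image_format)
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        return f"data:image/{image_format.lower()};base64,{encoded}"


    # Example: image_to_data_url(Image.new("RGB", (8, 8))).startswith("data:image/jpeg;base64,")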
@@ -92,6 +96,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         x, img_ids = prepare_latent_img_patches(x)
 
+        if not isinstance(transformer_info.config, CheckpointConfigBase):
+            raise ValueError("Transformer provided must be a valid Checkpoint model")
+
         is_schnell = "schnell" in transformer_info.config.config_path
 
         timesteps = get_schedule(
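The new isinstance guard is not just validation: the very next line reads transformer_info.config.config_path, and config_path only exists on checkpoint-style configs, so the early raise both fails fast and narrows the type for that attribute access. A minimal sketch of the pattern, using illustrative stand-in classes rather than InvokeAI's real config types:

    from typing import Union


    class CheckpointConfigBase:
        # Checkpoint-style configs carry a YAML config path that names the variant.
        config_path: str = "configs/flux-schnell.yaml"


    class DiffusersConfig:
        repo_id: str = "black-forest-labs/FLUX.1-schnell"


    def is_schnell_variant(config: Union[CheckpointConfigBase, DiffusersConfig]) -> bool:
        if not isinstance(config, CheckpointConfigBase):
            # Fail fast: only checkpoint configs have a config_path to inspect below.
            raise ValueError("Transformer provided must be a valid Checkpoint model")
        return "schnell" in config.config_path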
@@ -106,34 +113,46 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         with transformer_info as transformer:
             assert isinstance(transformer, Flux)
 
-            def step_callback() -> None:
+            def step_callback(img: torch.Tensor, state: PipelineIntermediateState) -> None:
                 if context.util.is_canceled():
                     raise CanceledException
-
-                # TODO: Make this look like the image before re-enabling
-                # latent_image = unpack(img.float(), self.height, self.width)
-                # latent_image = latent_image.squeeze() # Remove unnecessary dimensions
-                # flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
-
-                # # Create a new tensor of the required shape [255, 255, 3]
-                # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
-
-                # # Convert to a NumPy array and then to a PIL Image
-                # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
-
-                # (width, height) = image.size
-                # width *= 8
-                # height *= 8
-
-                # dataURL = image_to_dataURL(image, image_format="JPEG")
-
-                # # TODO: move this whole function to invocation context to properly reference these variables
-                # context._services.events.emit_invocation_denoise_progress(
-                #     context._data.queue_item,
-                #     context._data.invocation,
-                #     state,
-                #     ProgressImage(dataURL=dataURL, width=width, height=height),
-                # )
+                # Fixed per-channel factors projecting the 16 FLUX latent channels onto RGB.
+                latent_rgb_factors = torch.tensor([
+                    [-0.0371,  0.0132,  0.0587],
+                    [ 0.0029,  0.0264,  0.0824],
+                    [ 0.0297, -0.0723, -0.0479],
+                    [-0.0213,  0.0071,  0.0513],
+                    [ 0.0928,  0.0859,  0.0504],
+                    [ 0.0023,  0.0367,  0.0112],
+                    [ 0.0546,  0.1134,  0.1203],
+                    [-0.0174, -0.0191, -0.0284],
+                    [-0.0301,  0.0049,  0.0954],
+                    [ 0.0904,  0.0623, -0.0514],
+                    [-0.0483,  0.0213, -0.0012],
+                    [ 0.0477, -0.0007, -0.0083],
+                    [ 0.0936,  0.0898,  0.0930],
+                    [-0.1173, -0.0266, -0.0854],
+                    [-0.0004, -0.0507, -0.0019],
+                    [-0.1204, -0.0880, -0.0643],
+                ], dtype=img.dtype, device=img.device)
+                # Unpack the patch sequence back to a (channels, height/8, width/8) latent image.
+                latent_image = unpack(img.float(), self.height, self.width).squeeze()
+                latent_image_perm = latent_image.permute(1, 2, 0).to(dtype=img.dtype, device=img.device)
+                latent_image = latent_image_perm @ latent_rgb_factors
+                latents_ubyte = (
+                    ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF)  # scale from [-1, 1] to [0, 255]
+                ).to(device="cpu", dtype=torch.uint8)
+                image = Image.fromarray(latents_ubyte.cpu().numpy())
+                (width, height) = image.size
+                width *= 8
+                height *= 8
+                dataURL = image_to_dataURL(image, image_format="JPEG")
+
+                context._services.events.emit_invocation_denoise_progress(
+                    context._data.queue_item,
+                    context._data.invocation,
+                    state,
+                    ProgressImage(dataURL=dataURL, width=width, height=height),
+                )
 
 
             x = denoise(
                 model=transformer,
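The core of the new step_callback is the 16x3 latent_rgb_factors matrix: each of FLUX's 16 latent channels gets a fixed RGB contribution, so a (16, h, w) latent becomes an (h, w, 3) preview with one matrix multiply and no VAE decode. A standalone sketch of that projection, with shapes assumed from the diff above and random factors standing in for the hard-coded values:

    import torch
    from PIL import Image


    def latents_to_preview(latents: torch.Tensor, rgb_factors: torch.Tensor) -> Image.Image:
        # latents:     (16, h, w) unpacked FLUX latents (h = height / 8, w = width / 8)
        # rgb_factors: (16, 3) fixed per-channel RGB contributions
        rgb = latents.permute(1, 2, 0) @ rgb_factors  # (h, w, 16) @ (16, 3) -> (h, w, 3)
        # Map roughly [-1, 1] onto [0, 255], clamping outliers, and hand off to PIL.
        ubyte = ((rgb + 1) / 2).clamp(0, 1).mul(255).to(device="cpu", dtype=torch.uint8)
        return Image.fromarray(ubyte.numpy())


    # For a 1024x1024 render the latent is 128x128; the diff then multiplies the
    # reported width/height by 8 so the frontend scales the preview up.
    preview = latents_to_preview(torch.randn(16, 128, 128), torch.randn(16, 3) * 0.05)
    print(preview.size)  # (128, 128)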
invokeai/backend/flux/sampling.py

@@ -10,7 +10,7 @@ from tqdm import tqdm
 
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.conditioner import HFEncoder
-
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 
 def get_noise(
     num_samples: int,
@@ -108,9 +108,10 @@ def denoise(
     vec: Tensor,
     # sampling parameters
     timesteps: list[float],
-    step_callback: Callable[[], None],
+    step_callback: Callable[[Tensor, PipelineIntermediateState], None],
     guidance: float = 4.0,
 ):
+    step = 0
     # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
@@ -126,7 +127,14 @@ def denoise(
         )
 
         img = img + (t_prev - t_curr) * pred
-        step_callback()
+        step_callback(img, PipelineIntermediateState(
+            step=step,
+            order=1,
+            total_steps=len(timesteps),
+            timestep=int(t_curr),
+            latents=img,
+        ))
+        step += 1
 
     return img
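Taken together, the two sampling.py hunks turn denoise into a push-based progress source: each iteration hands the caller the in-flight packed latents plus a PipelineIntermediateState snapshot, and the callback can raise to abort, which is how the invocation implements cancellation. A minimal consumer, sketched with an illustrative stand-in for the real PipelineIntermediateState class:

    from torch import Tensor


    # Stand-in mirroring only the fields the diff populates; the real class
    # lives in invokeai.backend.stable_diffusion.diffusers_pipeline.
    class PipelineIntermediateState:
        def __init__(self, step: int, order: int, total_steps: int, timestep: int, latents: Tensor) -> None:
            self.step = step
            self.order = order
            self.total_steps = total_steps
            self.timestep = timestep
            self.latents = latents


    def on_step(img: Tensor, state: PipelineIntermediateState) -> None:
        # Called once per denoising iteration with the in-flight packed latents.
        # Raising an exception here aborts sampling.
        print(f"step {state.step + 1}/{state.total_steps}: latents {tuple(img.shape)}")


    # Wiring (other denoise arguments elided):
    #     denoise(..., timesteps=timesteps, step_callback=on_step)

One quirk worth noting: total_steps is set to len(timesteps) while the loop runs len(timesteps) - 1 times, and timestep=int(t_curr) appears to truncate FLUX's fractional schedule values to 0 after the first step, so a frontend should treat those two fields as approximate here.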