Fix FLUX output image clamping, and a few other minor fixes to make inference work with the full bfloat16 FLUX transformer model.
This commit is contained in:
parent a63f842a13
commit 0c5e11f521
@@ -17,9 +17,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 
 @invocation(
@@ -90,7 +90,11 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         img, img_ids = self._prepare_latent_img_patches(x)
 
         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
+        is_schnell = (
+            "schnell" in transformer_info.config.config_path
+            if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase)
+            else ""
+        )
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
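For reference, the reformatted expression behaves like any Python conditional expression: the substring test runs only when a checkpoint config is present, otherwise it falls back to the falsy "". A standalone sketch, with a hypothetical config_path value standing in for the real config object:

    # Hypothetical stand-ins for transformer_info.config and its config_path.
    config_path = "flux1-schnell.yaml"
    has_checkpoint_config = True  # stands in for the isinstance() check

    # Same shape as the diff: substring match, falling back to "" (falsy).
    is_schnell = "schnell" in config_path if has_checkpoint_config else ""
    print(is_schnell)  # True; "" when no checkpoint config exists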
@@ -161,7 +165,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             latents.to(torch.float32)
             img = vae.decode(latents)
 
-        img.clamp(-1, 1)
+        img = img.clamp(-1, 1)
         img = rearrange(img[0], "c h w -> h w c")
         img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
 
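This one-line change is the clamping fix named in the commit message. Tensor.clamp is not an in-place operation: it returns a new clamped tensor and leaves the original untouched, so the bare img.clamp(-1, 1) silently discarded its result and unclamped values flowed into the 0-255 conversion below. A minimal sketch of the difference:

    import torch

    x = torch.tensor([-2.0, 0.5, 3.0])

    x.clamp(-1, 1)      # returns a clamped copy; x itself is unchanged
    print(x)            # tensor([-2.0000,  0.5000,  3.0000])

    x = x.clamp(-1, 1)  # reassigning keeps the clamped result, as in the fix
    print(x)            # tensor([-1.0000,  0.5000,  1.0000])

    # The in-place variant x.clamp_(-1, 1) would work as well.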
@@ -104,9 +104,18 @@ def denoise(
     timesteps: list[float],
     guidance: float = 4.0,
 ):
+    dtype = model.txt_in.bias.dtype
+
+    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
+    img = img.to(dtype=dtype)
+    img_ids = img_ids.to(dtype=dtype)
+    txt = txt.to(dtype=dtype)
+    txt_ids = txt_ids.to(dtype=dtype)
+    vec = vec.to(dtype=dtype)
+
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
-    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
+    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=True):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         pred = model(
             img=img,
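The added casts make denoise() defensive about dtypes: the working dtype is read off one of the transformer's own parameters (model.txt_in.bias) and every tensor input is cast to match, so a bfloat16 model receives bfloat16 activations. The strict=True change additionally makes zip raise a ValueError on a length mismatch instead of silently truncating; the two slices of timesteps always match here, so it only tightens the invariant. A minimal sketch of the dtype pattern, using a toy module in place of the FLUX transformer:

    import torch

    model = torch.nn.Linear(4, 4).to(torch.bfloat16)

    # Read the working dtype off a parameter, as denoise() does with
    # model.txt_in.bias.dtype.
    dtype = model.bias.dtype

    x = torch.randn(2, 4)  # callers may still hand us float32
    x = x.to(dtype=dtype)  # cast to match the model before the forward pass

    print(model(x).dtype)  # torch.bfloat16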
@@ -1,12 +1,12 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
-import accelerate
-import torch
 from dataclasses import fields
 from pathlib import Path
 from typing import Any, Optional
 
+import accelerate
+import torch
 import yaml
 from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -25,15 +25,15 @@ from invokeai.backend.model_manager import (
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
     CLIPEmbedDiffusersConfig,
-    MainCheckpointConfig,
     MainBnbQuantized4bCheckpointConfig,
+    MainCheckpointConfig,
     T5EncoderConfig,
     VAECheckpointConfig,
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
-from invokeai.backend.util.silence_warnings import SilenceWarnings
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 app_config = get_config()
 
@@ -109,7 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
             case SubModelType.Tokenizer2:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
+                return T5EncoderModel.from_pretrained(
+                    Path(config.path) / "text_encoder_2"
+                )  # TODO: Fix hf subfolder install
 
         raise Exception("Only Checkpoint Flux models are currently supported.")
 
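The "# TODO: Fix hf subfolder install" comment concerns loading the encoder out of a subdirectory of the downloaded model. For reference, from_pretrained accepts a subfolder argument that targets a subdirectory directly; a hedged sketch (the path is a placeholder, and whether this fits the installer's layout is exactly what the TODO leaves open):

    from transformers import T5EncoderModel

    # Hypothetical: point from_pretrained at the model root and let it
    # resolve the "text_encoder_2" subfolder, instead of joining paths.
    model = T5EncoderModel.from_pretrained(
        "/path/to/flux-model",  # placeholder local directory
        subfolder="text_encoder_2",
    )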
@@ -153,7 +155,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
         params = FluxParams(**filtered_data)
 
         with SilenceWarnings():
-            model = load_class(params).to(self._torch_dtype)
+            model = load_class(params)
             sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
             return model
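Dropping the .to(self._torch_dtype) cast is what keeps the full bfloat16 transformer in bfloat16: load_state_dict(..., assign=True) replaces the module's parameters with the checkpoint tensors themselves rather than copying values into the pre-existing parameters, so the loaded model takes on the checkpoint's dtype with no extra cast. A minimal sketch of the behavior (requires PyTorch 2.1+ for assign=True):

    import torch

    model = torch.nn.Linear(4, 4)  # parameters are created as float32

    # Pretend the checkpoint was stored in bfloat16, like the full FLUX
    # transformer weights.
    sd = {k: v.to(torch.bfloat16) for k, v in model.state_dict().items()}

    # assign=True swaps the checkpoint tensors in directly, preserving
    # their dtype; the default (assign=False) would copy the values into
    # the existing float32 parameters instead.
    model.load_state_dict(sd, assign=True)
    print(model.weight.dtype)  # torch.bfloat16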