mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix FLUX output image clamping. And a few other minor fixes to make inference work with the full bfloat16 FLUX transformer model.
This commit is contained in:
parent
a63f842a13
commit
0c5e11f521
@ -17,9 +17,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -90,7 +90,11 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
img, img_ids = self._prepare_latent_img_patches(x)
|
||||
|
||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||
is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
|
||||
is_schnell = (
|
||||
"schnell" in transformer_info.config.config_path
|
||||
if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase)
|
||||
else ""
|
||||
)
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
@ -161,7 +165,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latents.to(torch.float32)
|
||||
img = vae.decode(latents)
|
||||
|
||||
img.clamp(-1, 1)
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
|
@ -104,9 +104,18 @@ def denoise(
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
dtype = model.txt_in.bias.dtype
|
||||
|
||||
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
|
||||
img = img.to(dtype=dtype)
|
||||
img_ids = img_ids.to(dtype=dtype)
|
||||
txt = txt.to(dtype=dtype)
|
||||
txt_ids = txt_ids.to(dtype=dtype)
|
||||
vec = vec.to(dtype=dtype)
|
||||
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
|
||||
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=True):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
|
@ -1,12 +1,12 @@
|
||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||
"""Class for Flux model loading in InvokeAI."""
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import yaml
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
@ -25,15 +25,15 @@ from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
MainCheckpointConfig,
|
||||
MainBnbQuantized4bCheckpointConfig,
|
||||
MainCheckpointConfig,
|
||||
T5EncoderConfig,
|
||||
VAECheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
@ -109,7 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
|
||||
return T5EncoderModel.from_pretrained(
|
||||
Path(config.path) / "text_encoder_2"
|
||||
) # TODO: Fix hf subfolder install
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
@ -153,7 +155,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
|
||||
params = FluxParams(**filtered_data)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = load_class(params).to(self._torch_dtype)
|
||||
model = load_class(params)
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
return model
|
||||
|
Loading…
Reference in New Issue
Block a user