Working inference node with quantized bnb nf4 checkpoint

This commit is contained in:
Brandon Rising 2024-08-19 13:12:38 -04:00 committed by Brandon
parent 2eb87f3306
commit 81f0886d6f
2 changed files with 65 additions and 11 deletions

View File

@ -89,7 +89,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
img, img_ids = self._prepare_latent_img_patches(x) img, img_ids = self._prepare_latent_img_patches(x)
# HACK(ryand): Find a better way to determine if this is a schnell model or not. # HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else "" is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
timesteps = get_schedule( timesteps = get_schedule(
num_steps=self.num_steps, num_steps=self.num_steps,
image_seq_len=img.shape[1], image_seq_len=img.shape[1],
@ -139,9 +139,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
img = repeat(img, "1 ... -> bs ...", bs=bs) img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids. # Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3) img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids return img, img_ids
@ -155,8 +155,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
with vae_info as vae: with vae_info as vae:
assert isinstance(vae, AutoEncoder) assert isinstance(vae, AutoEncoder)
# TODO(ryand): Test that this works with both float16 and bfloat16. # TODO(ryand): Test that this works with both float16 and bfloat16.
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()): # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
img = vae.decode(latents) vae.to(torch.float32)
latents.to(torch.float32)
img = vae.decode(latents)
img.clamp(-1, 1) img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") img = rearrange(img[0], "c h w -> h w c")

View File

@ -1,6 +1,8 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI.""" """Class for Flux model loading in InvokeAI."""
import accelerate
import torch
from dataclasses import fields from dataclasses import fields
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
@ -24,6 +26,7 @@ from invokeai.backend.model_manager.config import (
CheckpointConfigBase, CheckpointConfigBase,
CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig,
MainCheckpointConfig, MainCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
T5EncoderConfig, T5EncoderConfig,
VAECheckpointConfig, VAECheckpointConfig,
) )
@ -31,6 +34,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
app_config = get_config() app_config = get_config()
@ -62,7 +66,7 @@ class FluxVAELoader(GenericDiffusersLoader):
with SilenceWarnings(): with SilenceWarnings():
model = load_class(params).to(self._torch_dtype) model = load_class(params).to(self._torch_dtype)
# load_sft doesn't support torch.device # load_sft doesn't support torch.device
sd = load_file(model_path, device=str(TorchDevice.choose_torch_device())) sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True) model.load_state_dict(sd, strict=False, assign=True)
return model return model
@ -105,9 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
match submodel_type: match submodel_type:
case SubModelType.Tokenizer2: case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512) return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2: case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer") return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
raise Exception("Only Checkpoint Flux models are currently supported.") raise Exception("Only Checkpoint Flux models are currently supported.")
@ -152,7 +156,55 @@ class FluxCheckpointModel(GenericDiffusersLoader):
with SilenceWarnings(): with SilenceWarnings():
model = load_class(params).to(self._torch_dtype) model = load_class(params).to(self._torch_dtype)
# load_sft doesn't support torch.device sd = load_file(model_path)
sd = load_file(model_path, device=str(TorchDevice.choose_torch_device())) model.load_state_dict(sd, strict=False, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise Exception("Only Checkpoint Flux models are currently supported.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
try:
flux_conf = yaml.safe_load(stream)
except:
raise
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
raise Exception("Only Checkpoint Flux models are currently supported.")
def _load_from_singlefile(
self,
config: AnyModelConfig,
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
load_class = Flux
params = None
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = load_class(params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
# this on GPUs without bfloat16 support.
sd = load_file(model_path)
model.load_state_dict(sd, strict=False, assign=True) model.load_state_dict(sd, strict=False, assign=True)
return model return model