Working inference node with quantized bnb nf4 checkpoint

Brandon Rising 2024-08-19 13:12:38 -04:00
parent 4fb5529493
commit be6cb2c07c
2 changed files with 65 additions and 11 deletions

View File

@@ -89,7 +89,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        img, img_ids = self._prepare_latent_img_patches(x)

        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
+        is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=img.shape[1],
@@ -139,9 +139,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        img = repeat(img, "1 ... -> bs ...", bs=bs)

        # Generate patch position ids.
-        img_ids = torch.zeros(h // 2, w // 2, 3)
-        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
-        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        return img, img_ids
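For orientation, the hunk above only pins the position-id tensors to the latents' device so they are no longer created on the CPU. A minimal standalone sketch of the same pattern, assuming einops and a made-up function name (the real method also handles the single-to-batch repeat shown above):

```python
import torch
from einops import rearrange, repeat


def prepare_latent_img_patches(latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sketch: pack 2x2 latent patches into tokens and build matching position ids.

    For a (bs, c, h, w) latent tensor the packed sequence length is (h // 2) * (w // 2).
    """
    bs, _, h, w = latents.shape

    # Pack each 2x2 patch into a single token.
    img = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Build position ids on the same device as the latents (the fix in this hunk).
    img_ids = torch.zeros(h // 2, w // 2, 3, device=latents.device)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=latents.device)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=latents.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    return img, img_ids
```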
@@ -155,7 +155,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
+            # with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
+            vae.to(torch.float32)
+            latents.to(torch.float32)
            img = vae.decode(latents)

        img.clamp(-1, 1)
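The decode hunk above stops relying on torch.autocast and instead forces the VAE and the latents to float32 before decoding. A rough sketch of that flow with a hypothetical helper name (AutoEncoder stands in for the FLUX VAE class used above; note that Tensor.to() is not in-place, so the sketch reassigns the result):

```python
import torch


def decode_latents_fp32(vae, latents: torch.Tensor) -> torch.Tensor:
    """Sketch: decode FLUX latents with the VAE forced to float32, without autocast."""
    vae.to(torch.float32)                # module .to() converts parameters in place
    latents = latents.to(torch.float32)  # tensor .to() returns a new tensor

    with torch.no_grad():
        img = vae.decode(latents)

    # Keep the decoded image in [-1, 1] before it is converted to a PIL image elsewhere.
    return img.clamp(-1, 1)
```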

View File

@@ -1,6 +1,8 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""

+import accelerate
+import torch
from dataclasses import fields
from pathlib import Path
from typing import Any, Optional
@@ -24,6 +26,7 @@ from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    CLIPEmbedDiffusersConfig,
    MainCheckpointConfig,
+    MainBnbQuantized4bCheckpointConfig,
    T5EncoderConfig,
    VAECheckpointConfig,
)
@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings
+from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4

app_config = get_config()
@@ -62,7 +66,7 @@ class FluxVAELoader(GenericDiffusersLoader):
        with SilenceWarnings():
            model = load_class(params).to(self._torch_dtype)
            # load_sft doesn't support torch.device
-            sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+            sd = load_file(model_path)
            model.load_state_dict(sd, strict=False, assign=True)

        return model
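This hunk (and the matching one in FluxCheckpointModel below) drops the device argument to load_file, so the safetensors state dict is read on the CPU and handed to load_state_dict(..., assign=True), which swaps the loaded tensors in as the module's parameters instead of copying into pre-allocated storage. A small illustrative sketch of that pattern (the function name and path are placeholders):

```python
from safetensors.torch import load_file
from torch import nn


def load_checkpoint_cpu_then_assign(model: nn.Module, checkpoint_path: str) -> nn.Module:
    """Sketch: read a safetensors file on the CPU and assign its tensors into the model.

    assign=True replaces the module's parameters with the loaded tensors, which also
    works when the module was created with empty (meta) weights.
    """
    sd = load_file(checkpoint_path)  # load_file defaults to device="cpu"
    model.load_state_dict(sd, strict=False, assign=True)
    return model
```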
@@ -105,9 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
        match submodel_type:
            case SubModelType.Tokenizer2:
-                return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
+                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
            case SubModelType.TextEncoder2:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
+                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")  # TODO: Fix hf subfolder install

        raise Exception("Only Checkpoint Flux models are currently supported.")
@@ -152,7 +156,55 @@ class FluxCheckpointModel(GenericDiffusersLoader):
        with SilenceWarnings():
            model = load_class(params).to(self._torch_dtype)
            # load_sft doesn't support torch.device
-            sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
+            sd = load_file(model_path)
            model.load_state_dict(sd, strict=False, assign=True)

        return model

+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
+class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
+    """Class to load main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CheckpointConfigBase):
+            raise Exception("Only Checkpoint Flux models are currently supported.")
+        legacy_config_path = app_config.legacy_conf_path / config.config_path
+        config_path = legacy_config_path.as_posix()
+        with open(config_path, "r") as stream:
+            try:
+                flux_conf = yaml.safe_load(stream)
+            except:
+                raise
+
+        match submodel_type:
+            case SubModelType.Transformer:
+                return self._load_from_singlefile(config, flux_conf)
+
+        raise Exception("Only Checkpoint Flux models are currently supported.")
+
+    def _load_from_singlefile(
+        self,
+        config: AnyModelConfig,
+        flux_conf: Any,
+    ) -> AnyModel:
+        assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
+        load_class = Flux
+        params = None
+        model_path = Path(config.path)
+        dataclass_fields = {f.name for f in fields(FluxParams)}
+        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
+        params = FluxParams(**filtered_data)
+
+        with SilenceWarnings():
+            with accelerate.init_empty_weights():
+                model = load_class(params)
+                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+            # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
+            # this on GPUs without bfloat16 support.
+            sd = load_file(model_path)
+            model.load_state_dict(sd, strict=False, assign=True)
+        return model
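The new loader builds the FLUX transformer on empty (meta) weights via accelerate.init_empty_weights(), converts its layers to NF4-quantized ones with InvokeAI's quantize_model_nf4 helper, and only then materializes the real weights from the safetensors checkpoint with load_state_dict(..., assign=True). As a rough, hedged illustration of what NF4 quantization of linear layers generally looks like with bitsandbytes (this sketches the technique, not quantize_model_nf4's actual implementation):

```python
import bitsandbytes as bnb
import torch
from torch import nn


def replace_linears_with_nf4(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    """Sketch: recursively swap nn.Linear layers for bitsandbytes 4-bit NF4 linears."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            quantized = bnb.nn.Linear4bit(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                compute_dtype=compute_dtype,  # matmuls run in bfloat16
                quant_type="nf4",             # 4-bit NormalFloat weight storage
            )
            setattr(module, name, quantized)
        else:
            replace_linears_with_nf4(child, compute_dtype)
    return module
```

Because the transformer is instantiated under init_empty_weights(), the quantized modules hold no real data until load_state_dict(sd, strict=False, assign=True) assigns the checkpoint tensors (already NF4-packed, per the commit title) in place of the empty parameters.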