mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Working inference node with quantized bnb nf4 checkpoint
This commit is contained in:
parent
2eb87f3306
commit
81f0886d6f
@ -89,7 +89,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
img, img_ids = self._prepare_latent_img_patches(x)
|
||||
|
||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
|
||||
is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
@ -139,9 +139,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
@ -155,7 +155,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
||||
# with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
|
||||
vae.to(torch.float32)
|
||||
latents.to(torch.float32)
|
||||
img = vae.decode(latents)
|
||||
|
||||
img.clamp(-1, 1)
|
||||
|
@ -1,6 +1,8 @@
|
||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||
"""Class for Flux model loading in InvokeAI."""
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
@ -24,6 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
MainCheckpointConfig,
|
||||
MainBnbQuantized4bCheckpointConfig,
|
||||
T5EncoderConfig,
|
||||
VAECheckpointConfig,
|
||||
)
|
||||
@ -31,6 +34,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
@ -62,7 +66,7 @@ class FluxVAELoader(GenericDiffusersLoader):
|
||||
with SilenceWarnings():
|
||||
model = load_class(params).to(self._torch_dtype)
|
||||
# load_sft doesn't support torch.device
|
||||
sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
|
||||
return model
|
||||
@ -105,9 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
@ -152,7 +156,55 @@ class FluxCheckpointModel(GenericDiffusersLoader):
|
||||
|
||||
with SilenceWarnings():
|
||||
model = load_class(params).to(self._torch_dtype)
|
||||
# load_sft doesn't support torch.device
|
||||
sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
|
||||
class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
legacy_config_path = app_config.legacy_conf_path / config.config_path
|
||||
config_path = legacy_config_path.as_posix()
|
||||
with open(config_path, "r") as stream:
|
||||
try:
|
||||
flux_conf = yaml.safe_load(stream)
|
||||
except:
|
||||
raise
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config, flux_conf)
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
flux_conf: Any,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
|
||||
load_class = Flux
|
||||
params = None
|
||||
model_path = Path(config.path)
|
||||
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||
params = FluxParams(**filtered_data)
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = load_class(params)
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||
# this on GPUs without bfloat16 support.
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, strict=False, assign=True)
|
||||
return model
|
||||
|
Loading…
Reference in New Issue
Block a user