mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Working inference node with quantized bnb nf4 checkpoint
This commit is contained in:
parent
2eb87f3306
commit
81f0886d6f
@ -89,7 +89,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
img, img_ids = self._prepare_latent_img_patches(x)
|
img, img_ids = self._prepare_latent_img_patches(x)
|
||||||
|
|
||||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||||
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
|
is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
|
||||||
timesteps = get_schedule(
|
timesteps = get_schedule(
|
||||||
num_steps=self.num_steps,
|
num_steps=self.num_steps,
|
||||||
image_seq_len=img.shape[1],
|
image_seq_len=img.shape[1],
|
||||||
@ -139,9 +139,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||||
|
|
||||||
# Generate patch position ids.
|
# Generate patch position ids.
|
||||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
return img, img_ids
|
return img, img_ids
|
||||||
@ -155,8 +155,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
assert isinstance(vae, AutoEncoder)
|
assert isinstance(vae, AutoEncoder)
|
||||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||||
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
# with torch.autocast(device_type=latents.device.type, dtype=torch.float32):
|
||||||
img = vae.decode(latents)
|
vae.to(torch.float32)
|
||||||
|
latents.to(torch.float32)
|
||||||
|
img = vae.decode(latents)
|
||||||
|
|
||||||
img.clamp(-1, 1)
|
img.clamp(-1, 1)
|
||||||
img = rearrange(img[0], "c h w -> h w c")
|
img = rearrange(img[0], "c h w -> h w c")
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||||
"""Class for Flux model loading in InvokeAI."""
|
"""Class for Flux model loading in InvokeAI."""
|
||||||
|
|
||||||
|
import accelerate
|
||||||
|
import torch
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@ -24,6 +26,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
CheckpointConfigBase,
|
CheckpointConfigBase,
|
||||||
CLIPEmbedDiffusersConfig,
|
CLIPEmbedDiffusersConfig,
|
||||||
MainCheckpointConfig,
|
MainCheckpointConfig,
|
||||||
|
MainBnbQuantized4bCheckpointConfig,
|
||||||
T5EncoderConfig,
|
T5EncoderConfig,
|
||||||
VAECheckpointConfig,
|
VAECheckpointConfig,
|
||||||
)
|
)
|
||||||
@ -31,6 +34,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
|
|||||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||||
|
|
||||||
app_config = get_config()
|
app_config = get_config()
|
||||||
|
|
||||||
@ -62,7 +66,7 @@ class FluxVAELoader(GenericDiffusersLoader):
|
|||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
model = load_class(params).to(self._torch_dtype)
|
model = load_class(params).to(self._torch_dtype)
|
||||||
# load_sft doesn't support torch.device
|
# load_sft doesn't support torch.device
|
||||||
sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
|
sd = load_file(model_path)
|
||||||
model.load_state_dict(sd, strict=False, assign=True)
|
model.load_state_dict(sd, strict=False, assign=True)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@ -105,9 +109,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
match submodel_type:
|
match submodel_type:
|
||||||
case SubModelType.Tokenizer2:
|
case SubModelType.Tokenizer2:
|
||||||
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
|
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||||
case SubModelType.TextEncoder2:
|
case SubModelType.TextEncoder2:
|
||||||
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
|
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
|
||||||
|
|
||||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||||
|
|
||||||
@ -152,7 +156,55 @@ class FluxCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
model = load_class(params).to(self._torch_dtype)
|
model = load_class(params).to(self._torch_dtype)
|
||||||
# load_sft doesn't support torch.device
|
sd = load_file(model_path)
|
||||||
sd = load_file(model_path, device=str(TorchDevice.choose_torch_device()))
|
model.load_state_dict(sd, strict=False, assign=True)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
|
||||||
|
class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
|
||||||
|
"""Class to load main models."""
|
||||||
|
|
||||||
|
def _load_model(
|
||||||
|
self,
|
||||||
|
config: AnyModelConfig,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
) -> AnyModel:
|
||||||
|
if not isinstance(config, CheckpointConfigBase):
|
||||||
|
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||||
|
legacy_config_path = app_config.legacy_conf_path / config.config_path
|
||||||
|
config_path = legacy_config_path.as_posix()
|
||||||
|
with open(config_path, "r") as stream:
|
||||||
|
try:
|
||||||
|
flux_conf = yaml.safe_load(stream)
|
||||||
|
except:
|
||||||
|
raise
|
||||||
|
|
||||||
|
match submodel_type:
|
||||||
|
case SubModelType.Transformer:
|
||||||
|
return self._load_from_singlefile(config, flux_conf)
|
||||||
|
|
||||||
|
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||||
|
|
||||||
|
def _load_from_singlefile(
|
||||||
|
self,
|
||||||
|
config: AnyModelConfig,
|
||||||
|
flux_conf: Any,
|
||||||
|
) -> AnyModel:
|
||||||
|
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
|
||||||
|
load_class = Flux
|
||||||
|
params = None
|
||||||
|
model_path = Path(config.path)
|
||||||
|
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||||
|
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||||
|
params = FluxParams(**filtered_data)
|
||||||
|
|
||||||
|
with SilenceWarnings():
|
||||||
|
with accelerate.init_empty_weights():
|
||||||
|
model = load_class(params)
|
||||||
|
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||||
|
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||||
|
# this on GPUs without bfloat16 support.
|
||||||
|
sd = load_file(model_path)
|
||||||
model.load_state_dict(sd, strict=False, assign=True)
|
model.load_state_dict(sd, strict=False, assign=True)
|
||||||
return model
|
return model
|
||||||
|
Loading…
Reference in New Issue
Block a user