mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP on moving from diffusers to FLUX
This commit is contained in:
parent
3e8a550fab
commit
d7a39a4d67
@ -3,9 +3,12 @@ from typing import Literal
|
|||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
from einops import rearrange, repeat
|
||||||
|
from flux.model import Flux
|
||||||
|
from flux.modules.autoencoder import AutoEncoder
|
||||||
|
from flux.sampling import denoise, get_noise, get_schedule, unpack
|
||||||
|
from flux.util import configs as flux_configs
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers.models.auto import AutoModelForTextEncoding
|
from transformers.models.auto import AutoModelForTextEncoding
|
||||||
@ -21,11 +24,11 @@ from invokeai.app.invocations.fields import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
|
||||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
TFluxModelKeys = Literal["flux-schnell"]
|
TFluxModelKeys = Literal["flux-schnell"]
|
||||||
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
||||||
@ -70,7 +73,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
# model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||||
|
flux_transformer_path = context.models.download_and_cache_model(
|
||||||
|
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
|
||||||
|
)
|
||||||
|
flux_ae_path = context.models.download_and_cache_model(
|
||||||
|
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
# Load the conditioning data.
|
# Load the conditioning data.
|
||||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||||
@ -78,123 +87,155 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
flux_conditioning = cond_data.conditionings[0]
|
flux_conditioning = cond_data.conditionings[0]
|
||||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||||
|
|
||||||
latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
latents = self._run_diffusion(
|
||||||
image = self._run_vae_decoding(context, model_path, latents)
|
context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
|
||||||
|
)
|
||||||
|
image = self._run_vae_decoding(context, flux_ae_path, latents)
|
||||||
image_dto = context.images.save(image=image)
|
image_dto = context.images.save(image=image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
def _run_diffusion(
|
def _run_diffusion(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
flux_model_dir: Path,
|
flux_transformer_path: Path,
|
||||||
clip_embeddings: torch.Tensor,
|
clip_embeddings: torch.Tensor,
|
||||||
t5_embeddings: torch.Tensor,
|
t5_embeddings: torch.Tensor,
|
||||||
):
|
):
|
||||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
|
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
|
# Prepare input noise.
|
||||||
|
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
|
||||||
|
# CPU RNG?
|
||||||
|
x = get_noise(
|
||||||
|
num_samples=1,
|
||||||
|
height=self.height,
|
||||||
|
width=self.width,
|
||||||
|
device=TorchDevice.choose_torch_device(),
|
||||||
|
dtype=inference_dtype,
|
||||||
|
seed=self.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
img, img_ids = self._prepare_latent_img_patches(x)
|
||||||
|
|
||||||
|
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||||
|
is_schnell = "shnell" in str(flux_transformer_path)
|
||||||
|
timesteps = get_schedule(
|
||||||
|
num_steps=self.num_steps,
|
||||||
|
image_seq_len=img.shape[1],
|
||||||
|
shift=not is_schnell,
|
||||||
|
)
|
||||||
|
|
||||||
|
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||||
|
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||||
|
|
||||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||||
# if the cache is not empty.
|
# if the cache is not empty.
|
||||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||||
|
|
||||||
transformer_path = flux_model_dir / "transformer"
|
|
||||||
with context.models.load_local_model(
|
with context.models.load_local_model(
|
||||||
model_path=transformer_path, loader=self._load_flux_transformer
|
model_path=flux_transformer_path, loader=self._load_flux_transformer
|
||||||
) as transformer:
|
) as transformer:
|
||||||
assert isinstance(transformer, FluxTransformer2DModel)
|
assert isinstance(transformer, Flux)
|
||||||
|
|
||||||
flux_pipeline_with_transformer = FluxPipeline(
|
x = denoise(
|
||||||
scheduler=scheduler,
|
model=transformer,
|
||||||
vae=None,
|
img=img,
|
||||||
text_encoder=None,
|
img_ids=img_ids,
|
||||||
tokenizer=None,
|
txt=t5_embeddings,
|
||||||
text_encoder_2=None,
|
txt_ids=txt_ids,
|
||||||
tokenizer_2=None,
|
vec=clip_embeddings,
|
||||||
transformer=transformer,
|
timesteps=timesteps,
|
||||||
|
guidance=self.guidance,
|
||||||
)
|
)
|
||||||
|
|
||||||
dtype = torch.bfloat16
|
x = unpack(x.float(), self.height, self.width)
|
||||||
t5_embeddings = t5_embeddings.to(dtype=dtype)
|
|
||||||
clip_embeddings = clip_embeddings.to(dtype=dtype)
|
|
||||||
|
|
||||||
latents = flux_pipeline_with_transformer(
|
return x
|
||||||
height=self.height,
|
|
||||||
width=self.width,
|
|
||||||
num_inference_steps=self.num_steps,
|
|
||||||
guidance_scale=self.guidance,
|
|
||||||
generator=torch.Generator().manual_seed(self.seed),
|
|
||||||
prompt_embeds=t5_embeddings,
|
|
||||||
pooled_prompt_embeds=clip_embeddings,
|
|
||||||
output_type="latent",
|
|
||||||
return_dict=False,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
assert isinstance(latents, torch.Tensor)
|
def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
return latents
|
"""Convert an input image in latent space to patches for diffusion.
|
||||||
|
|
||||||
|
This implementation was extracted from:
|
||||||
|
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||||
|
"""
|
||||||
|
bs, c, h, w = latent_img.shape
|
||||||
|
|
||||||
|
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||||
|
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
|
if img.shape[0] == 1 and bs > 1:
|
||||||
|
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||||
|
|
||||||
|
# Generate patch position ids.
|
||||||
|
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||||
|
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||||
|
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
return img, img_ids
|
||||||
|
|
||||||
def _run_vae_decoding(
|
def _run_vae_decoding(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
flux_model_dir: Path,
|
flux_ae_path: Path,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
vae_path = flux_model_dir / "vae"
|
with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
|
||||||
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
|
assert isinstance(vae, AutoEncoder)
|
||||||
assert isinstance(vae, AutoencoderKL)
|
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||||
|
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
||||||
|
img = vae.decode(latents)
|
||||||
|
|
||||||
flux_pipeline_with_vae = FluxPipeline(
|
img.clamp(-1, 1)
|
||||||
scheduler=None,
|
img = rearrange(img[0], "c h w -> h w c")
|
||||||
vae=vae,
|
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||||
text_encoder=None,
|
|
||||||
tokenizer=None,
|
|
||||||
text_encoder_2=None,
|
|
||||||
tokenizer_2=None,
|
|
||||||
transformer=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
latents = flux_pipeline_with_vae._unpack_latents(
|
return img_pil
|
||||||
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
|
|
||||||
)
|
|
||||||
latents = (
|
|
||||||
latents / flux_pipeline_with_vae.vae.config.scaling_factor
|
|
||||||
) + flux_pipeline_with_vae.vae.config.shift_factor
|
|
||||||
latents = latents.to(dtype=vae.dtype)
|
|
||||||
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
|
|
||||||
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
|
|
||||||
|
|
||||||
assert isinstance(image, Image.Image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
|
||||||
|
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||||
if self.quantization_type == "raw":
|
if self.quantization_type == "raw":
|
||||||
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
|
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||||
elif self.quantization_type == "NF4":
|
params = flux_configs["flux-schnell"].params
|
||||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
|
||||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
|
||||||
|
|
||||||
model_nf4_path = path / "bnb_nf4"
|
# Initialize the model on the "meta" device.
|
||||||
assert model_nf4_path.exists()
|
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
model = Flux(params).to(inference_dtype)
|
||||||
|
|
||||||
|
state_dict = load_file(path)
|
||||||
|
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||||
|
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
|
elif self.quantization_type == "NF4":
|
||||||
|
model_path = path.parent / "bnb_nf4.safetensors"
|
||||||
|
|
||||||
|
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||||
|
params = flux_configs["flux-schnell"].params
|
||||||
|
# Initialize the model on the "meta" device.
|
||||||
|
with accelerate.init_empty_weights():
|
||||||
|
model = Flux(params)
|
||||||
|
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||||
|
|
||||||
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
# TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
|
||||||
# this on GPUs without bfloat16 support.
|
# this on GPUs without bfloat16 support.
|
||||||
sd = load_file(model_nf4_path / "model.safetensors")
|
state_dict = load_file(model_path)
|
||||||
model.load_state_dict(sd, strict=True, assign=True)
|
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
elif self.quantization_type == "llm_int8":
|
|
||||||
model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
|
||||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
|
||||||
model_int8_path = path / "bnb_llm_int8"
|
|
||||||
assert model_int8_path.exists()
|
|
||||||
with accelerate.init_empty_weights():
|
|
||||||
model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
|
||||||
|
|
||||||
sd = load_file(model_int8_path / "model.safetensors")
|
elif self.quantization_type == "llm_int8":
|
||||||
model.load_state_dict(sd, strict=True, assign=True)
|
raise NotImplementedError("LLM int8 quantization is not yet supported.")
|
||||||
|
# model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
|
||||||
|
# with accelerate.init_empty_weights():
|
||||||
|
# empty_model = FluxTransformer2DModel.from_config(model_config)
|
||||||
|
# assert isinstance(empty_model, FluxTransformer2DModel)
|
||||||
|
# model_int8_path = path / "bnb_llm_int8"
|
||||||
|
# assert model_int8_path.exists()
|
||||||
|
# with accelerate.init_empty_weights():
|
||||||
|
# model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
|
||||||
|
|
||||||
|
# sd = load_file(model_int8_path / "model.safetensors")
|
||||||
|
# model.load_state_dict(sd, strict=True, assign=True)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
|
||||||
|
|
||||||
@ -202,7 +243,12 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_flux_vae(path: Path) -> AutoencoderKL:
|
def _load_flux_vae(path: Path) -> AutoEncoder:
|
||||||
model = AutoencoderKL.from_pretrained(path, local_files_only=True)
|
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||||
assert isinstance(model, AutoencoderKL)
|
ae_params = flux_configs["flux1-schnell"].ae_params
|
||||||
return model
|
with accelerate.init_empty_weights():
|
||||||
|
ae = AutoEncoder(ae_params)
|
||||||
|
|
||||||
|
state_dict = load_file(path)
|
||||||
|
ae.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
|
return ae
|
||||||
|
@ -4,7 +4,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from flux.model import Flux
|
||||||
|
from flux.util import configs as flux_configs
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||||
@ -22,22 +23,24 @@ def log_time(name: str):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Load the FLUX transformer model onto the meta device.
|
|
||||||
model_path = Path(
|
model_path = Path(
|
||||||
"/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
|
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# inference_dtype = torch.bfloat16
|
||||||
with log_time("Intialize FLUX transformer on meta device"):
|
with log_time("Intialize FLUX transformer on meta device"):
|
||||||
model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
|
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||||
|
params = flux_configs["flux-schnell"].params
|
||||||
|
|
||||||
|
# Initialize the model on the "meta" device.
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
empty_model = FluxTransformer2DModel.from_config(model_config)
|
model = Flux(params)
|
||||||
assert isinstance(empty_model, FluxTransformer2DModel)
|
|
||||||
|
|
||||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||||
modules_to_not_convert: set[str] = set()
|
modules_to_not_convert: set[str] = set()
|
||||||
|
|
||||||
model_nf4_path = model_path / "bnb_nf4"
|
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
|
||||||
if model_nf4_path.exists():
|
if model_nf4_path.exists():
|
||||||
# The quantized model already exists, load it and return it.
|
# The quantized model already exists, load it and return it.
|
||||||
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
|
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
|
||||||
@ -45,12 +48,12 @@ def main():
|
|||||||
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
|
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
|
||||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||||
model = quantize_model_nf4(
|
model = quantize_model_nf4(
|
||||||
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||||
)
|
)
|
||||||
|
|
||||||
with log_time("Load state dict into model"):
|
with log_time("Load state dict into model"):
|
||||||
sd = load_file(model_nf4_path / "model.safetensors")
|
state_dict = load_file(model_nf4_path)
|
||||||
model.load_state_dict(sd, strict=True, assign=True)
|
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
|
|
||||||
with log_time("Move model to cuda"):
|
with log_time("Move model to cuda"):
|
||||||
model = model.to("cuda")
|
model = model.to("cuda")
|
||||||
@ -63,30 +66,24 @@ def main():
|
|||||||
|
|
||||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||||
model = quantize_model_nf4(
|
model = quantize_model_nf4(
|
||||||
empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||||
)
|
)
|
||||||
|
|
||||||
with log_time("Load state dict into model"):
|
with log_time("Load state dict into model"):
|
||||||
# Load sharded state dict.
|
state_dict = load_file(model_path)
|
||||||
files = list(model_path.glob("*.safetensors"))
|
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||||
state_dict = dict()
|
|
||||||
for file in files:
|
|
||||||
sd = load_file(file)
|
|
||||||
state_dict.update(sd)
|
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
|
|
||||||
with log_time("Move model to cuda and quantize"):
|
with log_time("Move model to cuda and quantize"):
|
||||||
model = model.to("cuda")
|
model = model.to("cuda")
|
||||||
|
|
||||||
with log_time("Save quantized model"):
|
with log_time("Save quantized model"):
|
||||||
model_nf4_path.mkdir(parents=True, exist_ok=True)
|
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
output_path = model_nf4_path / "model.safetensors"
|
save_file(model.state_dict(), model_nf4_path)
|
||||||
save_file(model.state_dict(), output_path)
|
|
||||||
|
|
||||||
print(f"Successfully quantized and saved model to '{output_path}'.")
|
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
|
||||||
|
|
||||||
assert isinstance(model, FluxTransformer2DModel)
|
assert isinstance(model, Flux)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ dependencies = [
|
|||||||
"controlnet-aux==0.0.7",
|
"controlnet-aux==0.0.7",
|
||||||
# TODO(ryand): Bump this once the next diffusers release is ready.
|
# TODO(ryand): Bump this once the next diffusers release is ready.
|
||||||
"diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
|
"diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
|
||||||
|
"flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3",
|
||||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||||
|
Loading…
Reference in New Issue
Block a user