Rename params for flux and flux vae; add comments explaining the use of config_path in the model config

This commit is contained in:
Brandon Rising 2024-08-26 15:42:42 -04:00 committed by Brandon
parent 2d185fb766
commit 65bb46bcca
5 changed files with 63 additions and 85 deletions
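In short, invokeai.backend.flux.util now exposes two plain dicts, params (keyed "flux-dev" / "flux-schnell") and ae_params (keyed "flux"), and callers index them with the key stored in the model config's config_path field. A minimal sketch of the intended usage, assuming only the imports shown below (heavy to actually run, since it builds the full models):

    from invokeai.backend.flux.model import Flux
    from invokeai.backend.flux.modules.autoencoder import AutoEncoder
    from invokeai.backend.flux.util import ae_params, params

    # Build the transformer for a given config key ("flux-dev" or "flux-schnell").
    transformer = Flux(params["flux-schnell"])
    # The FLUX VAE has a single parameter set, keyed as "flux".
    vae = AutoEncoder(ae_params["flux"])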

View File

@@ -25,7 +25,8 @@ max_seq_lengths: Dict[str, Literal[256, 512]] = {
}
ae_params = AutoEncoderParams(
ae_params = {
"flux": AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
@@ -36,15 +37,11 @@ ae_params = AutoEncoderParams(
scale_factor=0.3611,
shift_factor=0.1159,
)
}
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV"),
params=FluxParams(
params = {
"flux-dev": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
@@ -58,25 +55,7 @@ configs = {
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
"flux-schnell": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
@@ -90,17 +69,4 @@ configs = {
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
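For comparison, a sketch of the old and new access patterns (the old one is reconstructed from the removed lines above):

    from invokeai.backend.flux.util import ae_params, params

    # Before: a single `configs` dict of ModelSpec objects, which duplicated the VAE params per model:
    #   flux_params = configs["flux-dev"].params
    #   vae_params = configs["flux-dev"].ae_params
    # After: separate dicts for the transformer and the (shared) VAE.
    flux_params = params["flux-dev"]
    vae_params = ae_params["flux"]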

View File

@@ -12,7 +12,7 @@ from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CL
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, configs
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -59,7 +59,7 @@ class FluxVAELoader(ModelLoader):
model_path = Path(config.path)
with SilenceWarnings():
model = AutoEncoder(ae_params)
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
@@ -188,7 +188,7 @@ class FluxCheckpointModel(ModelLoader):
model_path = Path(config.path)
with SilenceWarnings():
model = Flux(configs[config.config_path].params)
model = Flux(params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model
@@ -227,7 +227,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(configs[config.config_path].params)
model = Flux(params[config.config_path])
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
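In these loaders, config.config_path is expected to hold one of the keys defined in util.py ("flux-dev" or "flux-schnell" for the transformer, "flux" for the VAE). A defensive variant of the lookup, not part of this commit, with a hypothetical build_flux helper just to make the failure mode explicit:

    from invokeai.backend.flux.model import Flux
    from invokeai.backend.flux.util import params

    def build_flux(config_path: str) -> Flux:
        # config_path comes from the model config; it is populated by the probe (see below).
        if config_path not in params:
            raise ValueError(f"Unknown FLUX config_path {config_path!r}; expected one of {sorted(params)}")
        return Flux(params[config_path])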

View File

@@ -329,8 +329,16 @@ class ModelProbe(object):
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict:
# For flux, this is a key in invokeai.backend.flux.util.params
# Because model type and format are the discriminators for model configs, this key is
# used rather than attempting to support flux with separate model types and formats.
# If that changes in the future, please fix this.
config_file = "flux-dev"
else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Because model type and format are the discriminators for model configs, this key is
# used rather than attempting to support flux with separate model types and formats.
# If that changes in the future, please fix this.
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
@@ -345,7 +353,11 @@
)
elif model_type is ModelType.VAE:
config_file = (
"flux/flux1-vae.yaml"
# For flux, this is a key in invokeai.backend.flux.util.ae_params
# Because model type and format are the discriminators for model configs, this key is
# used rather than attempting to support flux with separate model types and formats.
# If that changes in the future, please fix this.
"flux"
if base_type is BaseModelType.Flux
else "stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
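Condensed, the probe's key selection for FLUX checkpoints reduces to the sketch below (hypothetical standalone functions; the real logic lives inside ModelProbe and also covers the non-FLUX branches shown above):

    def flux_transformer_config_key(state_dict: dict) -> str:
        # Dev checkpoints include guidance-embedding weights; schnell checkpoints do not.
        return "flux-dev" if "guidance_in.out_layer.weight" in state_dict else "flux-schnell"

    def flux_vae_config_key() -> str:
        # Single parameter set for the FLUX VAE, keyed as "flux" in ae_params.
        return "flux"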

View File

@@ -4,7 +4,7 @@ import accelerate
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs as flux_configs
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
@@ -22,11 +22,11 @@ def main():
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
params = flux_configs["flux-schnell"].params
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(params)
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
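Both quantization scripts (this llm_int8 one and the nf4 one below) now follow the same pattern: index params directly instead of going through flux_configs[...].params. A minimal sketch, with the schnell key hard-coded as in the scripts (see the TODO above about detecting dev vs schnell):

    import accelerate

    from invokeai.backend.flux.model import Flux
    from invokeai.backend.flux.util import params

    with accelerate.init_empty_weights():
        # Parameters live on the "meta" device, so no real memory is allocated yet.
        model = Flux(params["flux-schnell"])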

View File

@@ -7,7 +7,7 @@ import torch
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs as flux_configs
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -35,11 +35,11 @@ def main():
# inference_dtype = torch.bfloat16
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
params = flux_configs["flux-schnell"].params
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(params)
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.