Rename params for flux and flux vae, add comments explaining use of the config_path in model config

This commit is contained in:
Brandon Rising 2024-08-26 15:42:42 -04:00 committed by Brandon
parent 2d185fb766
commit 65bb46bcca
5 changed files with 63 additions and 85 deletions

View File

@ -25,82 +25,48 @@ max_seq_lengths: Dict[str, Literal[256, 512]] = {
} }
ae_params = AutoEncoderParams( ae_params = {
resolution=256, "flux": AutoEncoderParams(
in_channels=3, resolution=256,
ch=128, in_channels=3,
out_ch=3, ch=128,
ch_mult=[1, 2, 4, 4], out_ch=3,
num_res_blocks=2, ch_mult=[1, 2, 4, 4],
z_channels=16, num_res_blocks=2,
scale_factor=0.3611, z_channels=16,
shift_factor=0.1159, scale_factor=0.3611,
) shift_factor=0.1159,
)
}
configs = { params = {
"flux-dev": ModelSpec( "flux-dev": FluxParams(
repo_id="black-forest-labs/FLUX.1-dev", in_channels=64,
repo_flow="flux1-dev.safetensors", vec_in_dim=768,
repo_ae="ae.safetensors", context_in_dim=4096,
ckpt_path=os.getenv("FLUX_DEV"), hidden_size=3072,
params=FluxParams( mlp_ratio=4.0,
in_channels=64, num_heads=24,
vec_in_dim=768, depth=19,
context_in_dim=4096, depth_single_blocks=38,
hidden_size=3072, axes_dim=[16, 56, 56],
mlp_ratio=4.0, theta=10_000,
num_heads=24, qkv_bias=True,
depth=19, guidance_embed=True,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
), ),
"flux-schnell": ModelSpec( "flux-schnell": FluxParams(
repo_id="black-forest-labs/FLUX.1-schnell", in_channels=64,
repo_flow="flux1-schnell.safetensors", vec_in_dim=768,
repo_ae="ae.safetensors", context_in_dim=4096,
ckpt_path=os.getenv("FLUX_SCHNELL"), hidden_size=3072,
params=FluxParams( mlp_ratio=4.0,
in_channels=64, num_heads=24,
vec_in_dim=768, depth=19,
context_in_dim=4096, depth_single_blocks=38,
hidden_size=3072, axes_dim=[16, 56, 56],
mlp_ratio=4.0, theta=10_000,
num_heads=24, qkv_bias=True,
depth=19, guidance_embed=False,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
), ),
} }

View File

@ -12,7 +12,7 @@ from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CL
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, configs from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import (
AnyModel, AnyModel,
AnyModelConfig, AnyModelConfig,
@ -59,7 +59,7 @@ class FluxVAELoader(ModelLoader):
model_path = Path(config.path) model_path = Path(config.path)
with SilenceWarnings(): with SilenceWarnings():
model = AutoEncoder(ae_params) model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path) sd = load_file(model_path)
model.load_state_dict(sd, assign=True) model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype) model.to(dtype=self._torch_dtype)
@ -188,7 +188,7 @@ class FluxCheckpointModel(ModelLoader):
model_path = Path(config.path) model_path = Path(config.path)
with SilenceWarnings(): with SilenceWarnings():
model = Flux(configs[config.config_path].params) model = Flux(params[config.config_path])
sd = load_file(model_path) sd = load_file(model_path)
model.load_state_dict(sd, assign=True) model.load_state_dict(sd, assign=True)
return model return model
@ -227,7 +227,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
with SilenceWarnings(): with SilenceWarnings():
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
model = Flux(configs[config.config_path].params) model = Flux(params[config.config_path])
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path) sd = load_file(model_path)
model.load_state_dict(sd, assign=True) model.load_state_dict(sd, assign=True)

View File

@ -329,8 +329,16 @@ class ModelProbe(object):
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict: if "guidance_in.out_layer.weight" in state_dict:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-dev" config_file = "flux-dev"
else: else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell" config_file = "flux-schnell"
else: else:
config_file = LEGACY_CONFIGS[base_type][variant_type] config_file = LEGACY_CONFIGS[base_type][variant_type]
@ -345,7 +353,11 @@ class ModelProbe(object):
) )
elif model_type is ModelType.VAE: elif model_type is ModelType.VAE:
config_file = ( config_file = (
"flux/flux1-vae.yaml" # For flux, this is a key in invokeai.backend.flux.util.ae_params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
"flux"
if base_type is BaseModelType.Flux if base_type is BaseModelType.Flux
else "stable-diffusion/v1-inference.yaml" else "stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1 if base_type is BaseModelType.StableDiffusion1

View File

@ -4,7 +4,7 @@ import accelerate
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs as flux_configs from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
@ -22,11 +22,11 @@ def main():
with log_time("Intialize FLUX transformer on meta device"): with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
params = flux_configs["flux-schnell"].params p = params["flux-schnell"]
# Initialize the model on the "meta" device. # Initialize the model on the "meta" device.
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
model = Flux(params) model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.

View File

@ -7,7 +7,7 @@ import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs as flux_configs from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@ -35,11 +35,11 @@ def main():
# inference_dtype = torch.bfloat16 # inference_dtype = torch.bfloat16
with log_time("Intialize FLUX transformer on meta device"): with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
params = flux_configs["flux-schnell"].params p = params["flux-schnell"]
# Initialize the model on the "meta" device. # Initialize the model on the "meta" device.
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
model = Flux(params) model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.