mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename params for flux and flux vae, add comments explaining use of the config_path in model config
This commit is contained in:
parent
2d185fb766
commit
65bb46bcca
@ -25,7 +25,8 @@ max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
}
|
||||
|
||||
|
||||
ae_params = AutoEncoderParams(
|
||||
ae_params = {
|
||||
"flux": AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
@ -35,16 +36,12 @@ ae_params = AutoEncoderParams(
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
configs = {
|
||||
"flux-dev": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-dev",
|
||||
repo_flow="flux1-dev.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_DEV"),
|
||||
params=FluxParams(
|
||||
params = {
|
||||
"flux-dev": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
@ -58,25 +55,7 @@ configs = {
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
"flux-schnell": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-schnell",
|
||||
repo_flow="flux1-schnell.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
||||
params=FluxParams(
|
||||
"flux-schnell": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
@ -90,17 +69,4 @@ configs = {
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CL
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.util import ae_params, configs
|
||||
from invokeai.backend.flux.util import ae_params, params
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@ -59,7 +59,7 @@ class FluxVAELoader(ModelLoader):
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = AutoEncoder(ae_params)
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
@ -188,7 +188,7 @@ class FluxCheckpointModel(ModelLoader):
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = Flux(configs[config.config_path].params)
|
||||
model = Flux(params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
@ -227,7 +227,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(configs[config.config_path].params)
|
||||
model = Flux(params[config.config_path])
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
|
@ -329,8 +329,16 @@ class ModelProbe(object):
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
if "guidance_in.out_layer.weight" in state_dict:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-dev"
|
||||
else:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
@ -345,7 +353,11 @@ class ModelProbe(object):
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
"flux/flux1-vae.yaml"
|
||||
# For flux, this is a key in invokeai.backend.flux.util.ae_params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
"flux"
|
||||
if base_type is BaseModelType.Flux
|
||||
else "stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
|
@ -4,7 +4,7 @@ import accelerate
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import configs as flux_configs
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
|
||||
|
||||
@ -22,11 +22,11 @@ def main():
|
||||
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
p = params["flux-schnell"]
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params)
|
||||
model = Flux(p)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
|
@ -7,7 +7,7 @@ import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import configs as flux_configs
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
|
||||
@ -35,11 +35,11 @@ def main():
|
||||
# inference_dtype = torch.bfloat16
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
params = flux_configs["flux-schnell"].params
|
||||
p = params["flux-schnell"]
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params)
|
||||
model = Flux(p)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
|
Loading…
Reference in New Issue
Block a user