mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 65bb46bcca (parent 2d185fb766)
Rename params for flux and flux vae, add comments explaining use of the config_path in model config
@@ -25,82 +25,48 @@ max_seq_lengths: Dict[str, Literal[256, 512]] = {
 }
 
 
-ae_params = AutoEncoderParams(
-    resolution=256,
-    in_channels=3,
-    ch=128,
-    out_ch=3,
-    ch_mult=[1, 2, 4, 4],
-    num_res_blocks=2,
-    z_channels=16,
-    scale_factor=0.3611,
-    shift_factor=0.1159,
-)
+ae_params = {
+    "flux": AutoEncoderParams(
+        resolution=256,
+        in_channels=3,
+        ch=128,
+        out_ch=3,
+        ch_mult=[1, 2, 4, 4],
+        num_res_blocks=2,
+        z_channels=16,
+        scale_factor=0.3611,
+        shift_factor=0.1159,
+    )
+}
 
 
-configs = {
-    "flux-dev": ModelSpec(
-        repo_id="black-forest-labs/FLUX.1-dev",
-        repo_flow="flux1-dev.safetensors",
-        repo_ae="ae.safetensors",
-        ckpt_path=os.getenv("FLUX_DEV"),
-        params=FluxParams(
-            in_channels=64,
-            vec_in_dim=768,
-            context_in_dim=4096,
-            hidden_size=3072,
-            mlp_ratio=4.0,
-            num_heads=24,
-            depth=19,
-            depth_single_blocks=38,
-            axes_dim=[16, 56, 56],
-            theta=10_000,
-            qkv_bias=True,
-            guidance_embed=True,
-        ),
-        ae_path=os.getenv("AE"),
-        ae_params=AutoEncoderParams(
-            resolution=256,
-            in_channels=3,
-            ch=128,
-            out_ch=3,
-            ch_mult=[1, 2, 4, 4],
-            num_res_blocks=2,
-            z_channels=16,
-            scale_factor=0.3611,
-            shift_factor=0.1159,
-        ),
-    ),
-    "flux-schnell": ModelSpec(
-        repo_id="black-forest-labs/FLUX.1-schnell",
-        repo_flow="flux1-schnell.safetensors",
-        repo_ae="ae.safetensors",
-        ckpt_path=os.getenv("FLUX_SCHNELL"),
-        params=FluxParams(
-            in_channels=64,
-            vec_in_dim=768,
-            context_in_dim=4096,
-            hidden_size=3072,
-            mlp_ratio=4.0,
-            num_heads=24,
-            depth=19,
-            depth_single_blocks=38,
-            axes_dim=[16, 56, 56],
-            theta=10_000,
-            qkv_bias=True,
-            guidance_embed=False,
-        ),
-        ae_path=os.getenv("AE"),
-        ae_params=AutoEncoderParams(
-            resolution=256,
-            in_channels=3,
-            ch=128,
-            out_ch=3,
-            ch_mult=[1, 2, 4, 4],
-            num_res_blocks=2,
-            z_channels=16,
-            scale_factor=0.3611,
-            shift_factor=0.1159,
-        ),
-    ),
-}
+params = {
+    "flux-dev": FluxParams(
+        in_channels=64,
+        vec_in_dim=768,
+        context_in_dim=4096,
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
+    "flux-schnell": FluxParams(
+        in_channels=64,
+        vec_in_dim=768,
+        context_in_dim=4096,
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=False,
+    ),
+}
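Net effect of this hunk: ae_params becomes a dict keyed by config name instead of a single AutoEncoderParams value, and the ModelSpec-based configs table (which bundled repo ids, checkpoint paths, and parameters) is replaced by a params dict that holds FluxParams directly. A minimal sketch of how the renamed tables are consumed downstream, using only the keys defined above:

    # Sketch only: instantiate FLUX models straight from the renamed lookup tables.
    # "flux-dev", "flux-schnell", and "flux" are the keys defined in this commit.
    from invokeai.backend.flux.model import Flux
    from invokeai.backend.flux.modules.autoencoder import AutoEncoder
    from invokeai.backend.flux.util import ae_params, params

    transformer = Flux(params["flux-dev"])  # transformer hyperparameters
    vae = AutoEncoder(ae_params["flux"])    # VAE hyperparameters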
@@ -12,7 +12,7 @@ from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CL
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
-from invokeai.backend.flux.util import ae_params, configs
+from invokeai.backend.flux.util import ae_params, params
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -59,7 +59,7 @@ class FluxVAELoader(ModelLoader):
         model_path = Path(config.path)
 
         with SilenceWarnings():
-            model = AutoEncoder(ae_params)
+            model = AutoEncoder(ae_params[config.config_path])
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             model.to(dtype=self._torch_dtype)
@@ -188,7 +188,7 @@ class FluxCheckpointModel(ModelLoader):
         model_path = Path(config.path)
 
         with SilenceWarnings():
-            model = Flux(configs[config.config_path].params)
+            model = Flux(params[config.config_path])
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             return model
@@ -227,7 +227,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
 
         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = Flux(configs[config.config_path].params)
+                model = Flux(params[config.config_path])
                 model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
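All three loaders now index the lookup tables with config.config_path, so an unrecognized value surfaces as a bare KeyError at model-construction time. A hypothetical guard (not part of this commit) that fails with a clearer message:

    # Hypothetical helper: validate config.config_path before building the model,
    # so a stale or misprobed key produces a readable error instead of a KeyError.
    from invokeai.backend.flux.model import Flux
    from invokeai.backend.flux.util import params

    def build_flux(config_path: str) -> Flux:
        if config_path not in params:
            raise ValueError(f"Unknown FLUX config_path {config_path!r}; expected one of {sorted(params)}")
        return Flux(params[config_path])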
@@ -329,8 +329,16 @@ class ModelProbe(object):
                 checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
                 state_dict = checkpoint.get("state_dict") or checkpoint
                 if "guidance_in.out_layer.weight" in state_dict:
+                    # For flux, this is a key in invokeai.backend.flux.util.params
+                    # Due to model type and format being the discriminator for model configs,
+                    # this is used rather than attempting to support flux with separate model types and formats.
+                    # If changed in the future, please fix me
                     config_file = "flux-dev"
                 else:
+                    # For flux, this is a key in invokeai.backend.flux.util.params
+                    # Due to model type and format being the discriminator for model configs,
+                    # this is used rather than attempting to support flux with separate model types and formats.
+                    # If changed in the future, please fix me
                     config_file = "flux-schnell"
             else:
                 config_file = LEGACY_CONFIGS[base_type][variant_type]
@@ -345,7 +353,11 @@ class ModelProbe(object):
             )
         elif model_type is ModelType.VAE:
             config_file = (
-                "flux/flux1-vae.yaml"
+                # For flux, this is a key in invokeai.backend.flux.util.ae_params
+                # Due to model type and format being the discriminator for model configs,
+                # this is used rather than attempting to support flux with separate model types and formats.
+                # If changed in the future, please fix me
+                "flux"
                 if base_type is BaseModelType.Flux
                 else "stable-diffusion/v1-inference.yaml"
                 if base_type is BaseModelType.StableDiffusion1
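With this change the probe emits a dict key ("flux-dev", "flux-schnell", or "flux") instead of a YAML path for FLUX models, making config_path the join point between probing and the loaders above. The dev/schnell choice reduces to a single state-dict check, sketched here (flux_config_key is an illustrative name, not a function in the commit):

    # Sketch of the probe's FLUX branch: dev checkpoints contain a guidance
    # embedder, so the presence of its output-layer weight selects "flux-dev".
    def flux_config_key(state_dict: dict) -> str:
        if "guidance_in.out_layer.weight" in state_dict:
            return "flux-dev"  # key into invokeai.backend.flux.util.params
        return "flux-schnell"

The FLUX VAE always maps to the single "flux" entry in ae_params.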
@@ -4,7 +4,7 @@ import accelerate
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.util import configs as flux_configs
+from invokeai.backend.flux.util import params
 from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
 from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
 
@@ -22,11 +22,11 @@ def main():
 
     with log_time("Intialize FLUX transformer on meta device"):
         # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        params = flux_configs["flux-schnell"].params
+        p = params["flux-schnell"]
 
         # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            model = Flux(params)
+            model = Flux(p)
 
         # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
         # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
@@ -7,7 +7,7 @@ import torch
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.util import configs as flux_configs
+from invokeai.backend.flux.util import params
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 
 
@@ -35,11 +35,11 @@ def main():
     # inference_dtype = torch.bfloat16
     with log_time("Intialize FLUX transformer on meta device"):
         # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        params = flux_configs["flux-schnell"].params
+        p = params["flux-schnell"]
 
         # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            model = Flux(params)
+            model = Flux(p)
 
         # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
         # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
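Both quantization scripts still hardcode "flux-schnell" (per the TODO above). One possible way to resolve it, not part of this commit, is to reuse the probe heuristic on the checkpoint before picking a key:

    # Hypothetical resolution of the TODO: derive the params key from the
    # checkpoint itself instead of hardcoding "flux-schnell".
    # model_path stands for the checkpoint path already defined in the script.
    from safetensors.torch import load_file
    from invokeai.backend.flux.util import params

    sd = load_file(model_path)
    key = "flux-dev" if "guidance_in.out_layer.weight" in sd else "flux-schnell"
    p = params[key]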