mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Select dev/schnell based on state dict, use correct max seq len based on dev/schnell, and shift in inference, separate vae flux params into separate config
This commit is contained in:
@ -32,7 +32,6 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
@ -60,7 +59,7 @@ class FluxVAELoader(GenericDiffusersLoader):
|
||||
raise
|
||||
|
||||
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
|
||||
filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
|
||||
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||
params = AutoEncoderParams(**filtered_data)
|
||||
|
||||
with SilenceWarnings():
|
||||
|
@ -324,7 +324,12 @@ class ModelProbe(object):
|
||||
if model_type is ModelType.Main:
|
||||
if base_type == BaseModelType.Flux:
|
||||
# TODO: Decide between dev/schnell
|
||||
config_file = "flux/flux1-schnell.yaml"
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
if 'guidance_in.out_layer.weight' in state_dict:
|
||||
config_file = "flux/flux1-dev.yaml"
|
||||
else:
|
||||
config_file = "flux/flux1-schnell.yaml"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
@ -338,7 +343,7 @@ class ModelProbe(object):
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
"flux/flux1-schnell.yaml"
|
||||
"flux/flux1-vae.yaml"
|
||||
if base_type is BaseModelType.Flux
|
||||
else "stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
|
Reference in New Issue
Block a user