Switch inheritance class of flux model loaders

This commit is contained in:
Brandon Rising 2024-08-21 11:30:16 -04:00 committed by Brandon
parent f7e46622a1
commit 87b7a2e39b

View File

@ -31,8 +31,8 @@ from invokeai.backend.model_manager.config import (
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -41,7 +41,7 @@ app_config = get_config()
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(GenericDiffusersLoader):
class FluxVAELoader(ModelLoader):
"""Class to load VAE models."""
def _load_model(
@ -75,7 +75,7 @@ class FluxVAELoader(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(GenericDiffusersLoader):
class ClipCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
@ -96,7 +96,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
class T5Encoder8bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
@ -117,7 +117,7 @@ class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(GenericDiffusersLoader):
class T5EncoderCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
@ -140,7 +140,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(GenericDiffusersLoader):
class FluxCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
@ -185,7 +185,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(