From 0913d062d87f5cfca19ef15fba240fb505c3de6d Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 21 Aug 2024 11:30:16 -0400 Subject: [PATCH] Switch inheritance class of flux model loaders --- .../model_manager/load/model_loaders/flux.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 6502339a24..44444092e9 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -31,8 +31,8 @@ from invokeai.backend.model_manager.config import ( T5EncoderConfig, VAECheckpointConfig, ) +from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.util.silence_warnings import SilenceWarnings @@ -41,7 +41,7 @@ app_config = get_config() @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint) -class FluxVAELoader(GenericDiffusersLoader): +class FluxVAELoader(ModelLoader): """Class to load VAE models.""" def _load_model( @@ -75,7 +75,7 @@ class FluxVAELoader(GenericDiffusersLoader): @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers) -class ClipCheckpointModel(GenericDiffusersLoader): +class ClipCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( @@ -96,7 +96,7 @@ class ClipCheckpointModel(GenericDiffusersLoader): @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b) -class T5Encoder8bCheckpointModel(GenericDiffusersLoader): +class T5Encoder8bCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( @@ -117,7 +117,7 @@ class T5Encoder8bCheckpointModel(GenericDiffusersLoader): @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) -class T5EncoderCheckpointModel(GenericDiffusersLoader): +class T5EncoderCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( @@ -140,7 +140,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader): @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) -class FluxCheckpointModel(GenericDiffusersLoader): +class FluxCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model( @@ -185,7 +185,7 @@ class FluxCheckpointModel(GenericDiffusersLoader): @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b) -class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader): +class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): """Class to load main models.""" def _load_model(