Switch inheritance class of flux model loaders

This commit is contained in:
Brandon Rising 2024-08-21 11:30:16 -04:00
parent d4872253a1
commit 0913d062d8

View File

@ -31,8 +31,8 @@ from invokeai.backend.model_manager.config import (
T5EncoderConfig, T5EncoderConfig,
VAECheckpointConfig, VAECheckpointConfig,
) )
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -41,7 +41,7 @@ app_config = get_config()
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(GenericDiffusersLoader): class FluxVAELoader(ModelLoader):
"""Class to load VAE models.""" """Class to load VAE models."""
def _load_model( def _load_model(
@ -75,7 +75,7 @@ class FluxVAELoader(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(GenericDiffusersLoader): class ClipCheckpointModel(ModelLoader):
"""Class to load main models.""" """Class to load main models."""
def _load_model( def _load_model(
@ -96,7 +96,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
class T5Encoder8bCheckpointModel(GenericDiffusersLoader): class T5Encoder8bCheckpointModel(ModelLoader):
"""Class to load main models.""" """Class to load main models."""
def _load_model( def _load_model(
@ -117,7 +117,7 @@ class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(GenericDiffusersLoader): class T5EncoderCheckpointModel(ModelLoader):
"""Class to load main models.""" """Class to load main models."""
def _load_model( def _load_model(
@ -140,7 +140,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(GenericDiffusersLoader): class FluxCheckpointModel(ModelLoader):
"""Class to load main models.""" """Class to load main models."""
def _load_model( def _load_model(
@ -185,7 +185,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b) @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader): class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models.""" """Class to load main models."""
def _load_model( def _load_model(