mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch inheritance class of flux model loaders
This commit is contained in:
parent
d4872253a1
commit
0913d062d8
@ -31,8 +31,8 @@ from invokeai.backend.model_manager.config import (
|
|||||||
T5EncoderConfig,
|
T5EncoderConfig,
|
||||||
VAECheckpointConfig,
|
VAECheckpointConfig,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
|
||||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
@ -41,7 +41,7 @@ app_config = get_config()
|
|||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||||
class FluxVAELoader(GenericDiffusersLoader):
|
class FluxVAELoader(ModelLoader):
|
||||||
"""Class to load VAE models."""
|
"""Class to load VAE models."""
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
@ -75,7 +75,7 @@ class FluxVAELoader(GenericDiffusersLoader):
|
|||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
|
||||||
class ClipCheckpointModel(GenericDiffusersLoader):
|
class ClipCheckpointModel(ModelLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
@ -96,7 +96,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
|
||||||
class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
|
class T5Encoder8bCheckpointModel(ModelLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
@ -117,7 +117,7 @@ class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
||||||
class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
class T5EncoderCheckpointModel(ModelLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
@ -140,7 +140,7 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||||
class FluxCheckpointModel(GenericDiffusersLoader):
|
class FluxCheckpointModel(ModelLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
@ -185,7 +185,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
|
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
|
||||||
class FluxBnbQuantizednf4bCheckpointModel(GenericDiffusersLoader):
|
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||||
"""Class to load main models."""
|
"""Class to load main models."""
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
|
Loading…
Reference in New Issue
Block a user