add concept of repo variant

This commit is contained in:
Lincoln Stein 2024-01-22 14:37:23 -05:00 committed by Brandon Rising
parent 55147fbb7e
commit 66e2d1b346
5 changed files with 65 additions and 4 deletions

View File

@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models.""" """Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
class LoRAConfig(ModelConfigBase): class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models.""" """Model config for LoRA/Lycoris models."""
@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig): class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False upcast_attention: bool = False
class ONNXSD1Config(_MainConfig): class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1.""" """Model config for ONNX format models based on sd-1."""

View File

@ -20,6 +20,7 @@ from .config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
ModelVariantType, ModelVariantType,
ModelRepoVariant,
SchedulerPredictionType, SchedulerPredictionType,
) )
from .hash import FastModelHash from .hash import FastModelHash
@ -155,6 +156,9 @@ class ModelProbe(object):
fields["original_hash"] = fields.get("original_hash") or hash fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash
if format_type == ModelFormat.Diffusers:
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models # additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path( fields["config"] = cls._get_checkpoint_config_path(
@ -477,6 +481,20 @@ class FolderProbeBase(ProbeBase):
def get_format(self) -> ModelFormat: def get_format(self) -> ModelFormat:
return ModelFormat("diffusers") return ModelFormat("diffusers")
def get_repo_variant(self) -> ModelRepoVariant:
# get all files ending in .bin or .safetensors
weight_files = list(self.model_path.glob('**/*.safetensors'))
weight_files.extend(list(self.model_path.glob('**/*.bin')))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OPENVINO
if "flax_model" in x.name:
return ModelRepoVariant.FLAX
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.DEFAULT
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
@ -524,6 +542,7 @@ class PipelineFolderProbe(FolderProbeBase):
return ModelVariantType.Normal return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase): class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl(): if self._config_looks_like_sdxl():

View File

@ -3,7 +3,7 @@ from pathlib import Path
import pytest import pytest
from invokeai.backend import BaseModelType from invokeai.backend import BaseModelType
from invokeai.backend.model_management.model_probe import VaeFolderProbe from invokeai.backend.model_manager.probe import VaeFolderProbe
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -20,3 +20,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
probe = VaeFolderProbe(sd1_vae_path) probe = VaeFolderProbe(sd1_vae_path)
base_type = probe.get_base_type() base_type = probe.get_base_type()
assert base_type == expected_type assert base_type == expected_type
repo_variant = probe.get_repo_variant()
assert repo_variant == 'default'
def test_repo_variant(datadir: Path):
probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16")
repo_variant = probe.get_repo_variant()
assert repo_variant == 'fp16'

View File

@ -0,0 +1,37 @@
{
"_class_name": "AutoencoderTiny",
"_diffusers_version": "0.20.0.dev0",
"act_fn": "relu",
"decoder_block_out_channels": [
64,
64,
64,
64
],
"encoder_block_out_channels": [
64,
64,
64,
64
],
"force_upcast": false,
"in_channels": 3,
"latent_channels": 4,
"latent_magnitude": 3,
"latent_shift": 0.5,
"num_decoder_blocks": [
3,
3,
3,
1
],
"num_encoder_blocks": [
1,
3,
3,
3
],
"out_channels": 3,
"scaling_factor": 1.0,
"upsampling_scaling_factor": 2
}