mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add concept of repo variant
This commit is contained in:
parent
6a2856e46f
commit
6b8a6e12bc
@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD1Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
|
@ -20,6 +20,7 @@ from .config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
ModelRepoVariant,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
@ -155,6 +156,9 @@ class ModelProbe(object):
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
|
||||
if format_type == ModelFormat.Diffusers:
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
@ -477,6 +481,20 @@ class FolderProbeBase(ProbeBase):
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat("diffusers")
|
||||
|
||||
def get_repo_variant(self) -> ModelRepoVariant:
|
||||
# get all files ending in .bin or .safetensors
|
||||
weight_files = list(self.model_path.glob('**/*.safetensors'))
|
||||
weight_files.extend(list(self.model_path.glob('**/*.bin')))
|
||||
for x in weight_files:
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OPENVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.FLAX
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.DEFAULT
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
@ -524,6 +542,7 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from invokeai.backend import BaseModelType
|
||||
from invokeai.backend.model_management.model_probe import VaeFolderProbe
|
||||
from invokeai.backend.model_manager.probe import VaeFolderProbe
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -20,3 +20,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
|
||||
probe = VaeFolderProbe(sd1_vae_path)
|
||||
base_type = probe.get_base_type()
|
||||
assert base_type == expected_type
|
||||
repo_variant = probe.get_repo_variant()
|
||||
assert repo_variant == 'default'
|
||||
|
||||
def test_repo_variant(datadir: Path):
|
||||
probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16")
|
||||
repo_variant = probe.get_repo_variant()
|
||||
assert repo_variant == 'fp16'
|
||||
|
37
tests/test_model_probe/vae/taesdxl-fp16/config.json
Normal file
37
tests/test_model_probe/vae/taesdxl-fp16/config.json
Normal file
@ -0,0 +1,37 @@
|
||||
{
|
||||
"_class_name": "AutoencoderTiny",
|
||||
"_diffusers_version": "0.20.0.dev0",
|
||||
"act_fn": "relu",
|
||||
"decoder_block_out_channels": [
|
||||
64,
|
||||
64,
|
||||
64,
|
||||
64
|
||||
],
|
||||
"encoder_block_out_channels": [
|
||||
64,
|
||||
64,
|
||||
64,
|
||||
64
|
||||
],
|
||||
"force_upcast": false,
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"latent_magnitude": 3,
|
||||
"latent_shift": 0.5,
|
||||
"num_decoder_blocks": [
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
1
|
||||
],
|
||||
"num_encoder_blocks": [
|
||||
1,
|
||||
3,
|
||||
3,
|
||||
3
|
||||
],
|
||||
"out_channels": 3,
|
||||
"scaling_factor": 1.0,
|
||||
"upsampling_scaling_factor": 2
|
||||
}
|
Loading…
Reference in New Issue
Block a user