mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add concept of repo variant
This commit is contained in:
parent
6a2856e46f
commit
6b8a6e12bc
@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase):
|
|||||||
"""Model config for diffusers-style models."""
|
"""Model config for diffusers-style models."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||||
|
|
||||||
class LoRAConfig(ModelConfigBase):
|
class LoRAConfig(ModelConfigBase):
|
||||||
"""Model config for LoRA/Lycoris models."""
|
"""Model config for LoRA/Lycoris models."""
|
||||||
@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
|
|||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
|||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
upcast_attention: bool = False
|
upcast_attention: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD1Config(_MainConfig):
|
class ONNXSD1Config(_MainConfig):
|
||||||
"""Model config for ONNX format models based on sd-1."""
|
"""Model config for ONNX format models based on sd-1."""
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from .config import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
|
ModelRepoVariant,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
from .hash import FastModelHash
|
from .hash import FastModelHash
|
||||||
@ -155,6 +156,9 @@ class ModelProbe(object):
|
|||||||
fields["original_hash"] = fields.get("original_hash") or hash
|
fields["original_hash"] = fields.get("original_hash") or hash
|
||||||
fields["current_hash"] = fields.get("current_hash") or hash
|
fields["current_hash"] = fields.get("current_hash") or hash
|
||||||
|
|
||||||
|
if format_type == ModelFormat.Diffusers:
|
||||||
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
|
||||||
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||||
fields["config"] = cls._get_checkpoint_config_path(
|
fields["config"] = cls._get_checkpoint_config_path(
|
||||||
@ -477,6 +481,20 @@ class FolderProbeBase(ProbeBase):
|
|||||||
def get_format(self) -> ModelFormat:
|
def get_format(self) -> ModelFormat:
|
||||||
return ModelFormat("diffusers")
|
return ModelFormat("diffusers")
|
||||||
|
|
||||||
|
def get_repo_variant(self) -> ModelRepoVariant:
|
||||||
|
# get all files ending in .bin or .safetensors
|
||||||
|
weight_files = list(self.model_path.glob('**/*.safetensors'))
|
||||||
|
weight_files.extend(list(self.model_path.glob('**/*.bin')))
|
||||||
|
for x in weight_files:
|
||||||
|
if ".fp16" in x.suffixes:
|
||||||
|
return ModelRepoVariant.FP16
|
||||||
|
if "openvino_model" in x.name:
|
||||||
|
return ModelRepoVariant.OPENVINO
|
||||||
|
if "flax_model" in x.name:
|
||||||
|
return ModelRepoVariant.FLAX
|
||||||
|
if x.suffix == ".onnx":
|
||||||
|
return ModelRepoVariant.ONNX
|
||||||
|
return ModelRepoVariant.DEFAULT
|
||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
@ -522,6 +540,7 @@ class PipelineFolderProbe(FolderProbeBase):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VaeFolderProbe(FolderProbeBase):
|
class VaeFolderProbe(FolderProbeBase):
|
||||||
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from invokeai.backend import BaseModelType
|
from invokeai.backend import BaseModelType
|
||||||
from invokeai.backend.model_management.model_probe import VaeFolderProbe
|
from invokeai.backend.model_manager.probe import VaeFolderProbe
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -20,3 +20,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
|
|||||||
probe = VaeFolderProbe(sd1_vae_path)
|
probe = VaeFolderProbe(sd1_vae_path)
|
||||||
base_type = probe.get_base_type()
|
base_type = probe.get_base_type()
|
||||||
assert base_type == expected_type
|
assert base_type == expected_type
|
||||||
|
repo_variant = probe.get_repo_variant()
|
||||||
|
assert repo_variant == 'default'
|
||||||
|
|
||||||
|
def test_repo_variant(datadir: Path):
|
||||||
|
probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16")
|
||||||
|
repo_variant = probe.get_repo_variant()
|
||||||
|
assert repo_variant == 'fp16'
|
||||||
|
37
tests/test_model_probe/vae/taesdxl-fp16/config.json
Normal file
37
tests/test_model_probe/vae/taesdxl-fp16/config.json
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "AutoencoderTiny",
|
||||||
|
"_diffusers_version": "0.20.0.dev0",
|
||||||
|
"act_fn": "relu",
|
||||||
|
"decoder_block_out_channels": [
|
||||||
|
64,
|
||||||
|
64,
|
||||||
|
64,
|
||||||
|
64
|
||||||
|
],
|
||||||
|
"encoder_block_out_channels": [
|
||||||
|
64,
|
||||||
|
64,
|
||||||
|
64,
|
||||||
|
64
|
||||||
|
],
|
||||||
|
"force_upcast": false,
|
||||||
|
"in_channels": 3,
|
||||||
|
"latent_channels": 4,
|
||||||
|
"latent_magnitude": 3,
|
||||||
|
"latent_shift": 0.5,
|
||||||
|
"num_decoder_blocks": [
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"num_encoder_blocks": [
|
||||||
|
1,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3
|
||||||
|
],
|
||||||
|
"out_channels": 3,
|
||||||
|
"scaling_factor": 1.0,
|
||||||
|
"upsampling_scaling_factor": 2
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user