add concept of repo variant

This commit is contained in:
Lincoln Stein
2024-01-22 14:37:23 -05:00
committed by psychedelicious
parent f505ec64ba
commit a1307b9f2e
5 changed files with 65 additions and 4 deletions

View File

@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""

View File

@ -20,6 +20,7 @@ from .config import (
ModelFormat,
ModelType,
ModelVariantType,
ModelRepoVariant,
SchedulerPredictionType,
)
from .hash import FastModelHash
@ -155,6 +156,9 @@ class ModelProbe(object):
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
if format_type == ModelFormat.Diffusers:
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path(
@ -477,6 +481,20 @@ class FolderProbeBase(ProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat("diffusers")
def get_repo_variant(self) -> ModelRepoVariant:
# get all files ending in .bin or .safetensors
weight_files = list(self.model_path.glob('**/*.safetensors'))
weight_files.extend(list(self.model_path.glob('**/*.bin')))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OPENVINO
if "flax_model" in x.name:
return ModelRepoVariant.FLAX
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.DEFAULT
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
@ -522,6 +540,7 @@ class PipelineFolderProbe(FolderProbeBase):
except Exception:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):