mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add concept of repo variant
This commit is contained in:
committed by
psychedelicious
parent
f505ec64ba
commit
a1307b9f2e
@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD1Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
|
@ -20,6 +20,7 @@ from .config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
ModelRepoVariant,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
@ -155,6 +156,9 @@ class ModelProbe(object):
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
|
||||
if format_type == ModelFormat.Diffusers:
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
@ -477,6 +481,20 @@ class FolderProbeBase(ProbeBase):
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat("diffusers")
|
||||
|
||||
def get_repo_variant(self) -> ModelRepoVariant:
|
||||
# get all files ending in .bin or .safetensors
|
||||
weight_files = list(self.model_path.glob('**/*.safetensors'))
|
||||
weight_files.extend(list(self.model_path.glob('**/*.bin')))
|
||||
for x in weight_files:
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OPENVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.FLAX
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.DEFAULT
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
@ -522,6 +540,7 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
except Exception:
|
||||
pass
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
|
Reference in New Issue
Block a user