feat(model management): guess whether a VAE is for SDXL based on its name

This commit is contained in:
Kevin Turner 2023-09-20 12:03:15 -07:00
parent f222b871e9
commit e0f8274f49

View File

@ -1,4 +1,5 @@
import json import json
import re
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union from typing import Callable, Dict, Literal, Optional, Union
@ -469,16 +470,32 @@ class PipelineFolderProbe(FolderProbeBase):
class VaeFolderProbe(FolderProbeBase): class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), necessarily tell them apart by config hyperparameters.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.folder_path / "config.json" config_file = self.folder_path / "config.json"
if not config_file.exists(): if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file: with open(config_file, "r") as file:
config = json.load(file) config = json.load(file)
return ( return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
BaseModelType.StableDiffusionXL
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] def _name_looks_like_sdxl(self) -> bool:
else BaseModelType.StableDiffusion1 return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
)
def _guess_name(self) -> str:
name = self.folder_path.name
if name == "vae":
name = self.folder_path.parent.name
return name
class TextualInversionFolderProbe(FolderProbeBase): class TextualInversionFolderProbe(FolderProbeBase):