return correct base type for sd3 VAEs

This commit is contained in:
Lincoln Stein 2024-06-15 18:17:03 -04:00
parent ac0396e6f7
commit 554809c647

View File

@ -650,6 +650,8 @@ class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._config_looks_like_sd3():
return BaseModelType.StableDiffusion3
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
@ -669,6 +671,15 @@ class VaeFolderProbe(FolderProbeBase):
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _config_looks_like_sd3(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
def _guess_name(self) -> str:
name = self.model_path.name
if name == "vae":