mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
return correct base type for sd3 VAEs
This commit is contained in:
parent
ac0396e6f7
commit
554809c647
@ -650,6 +650,8 @@ class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._config_looks_like_sd3():
|
||||
return BaseModelType.StableDiffusion3
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
@ -669,6 +671,15 @@ class VaeFolderProbe(FolderProbeBase):
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _config_looks_like_sd3(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.model_path.name
|
||||
if name == "vae":
|
||||
|
Loading…
Reference in New Issue
Block a user