mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(model management): guess whether a VAE is for SDXL based on its name
This commit is contained in:
parent
f222b871e9
commit
e0f8274f49
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Literal, Optional, Union
|
from typing import Callable, Dict, Literal, Optional, Union
|
||||||
@ -469,16 +470,32 @@ class PipelineFolderProbe(FolderProbeBase):
|
|||||||
|
|
||||||
class VaeFolderProbe(FolderProbeBase):
|
class VaeFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
if self._config_looks_like_sdxl():
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
elif self._name_looks_like_sdxl():
|
||||||
|
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||||
|
# by a factor of 8), necessarily tell them apart by config hyperparameters.
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
else:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
|
def _config_looks_like_sdxl(self) -> bool:
|
||||||
|
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||||
config_file = self.folder_path / "config.json"
|
config_file = self.folder_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||||
with open(config_file, "r") as file:
|
with open(config_file, "r") as file:
|
||||||
config = json.load(file)
|
config = json.load(file)
|
||||||
return (
|
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||||
BaseModelType.StableDiffusionXL
|
|
||||||
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
def _name_looks_like_sdxl(self) -> bool:
|
||||||
else BaseModelType.StableDiffusion1
|
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||||
)
|
|
||||||
|
def _guess_name(self) -> str:
|
||||||
|
name = self.folder_path.name
|
||||||
|
if name == "vae":
|
||||||
|
name = self.folder_path.parent.name
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionFolderProbe(FolderProbeBase):
|
class TextualInversionFolderProbe(FolderProbeBase):
|
||||||
|
Loading…
Reference in New Issue
Block a user