simplify mapping of detectors to controlnet processors; tweaked download cache name algorithm

This commit is contained in:
Lincoln Stein 2024-07-20 18:34:57 -04:00
parent 1e357bd21b
commit a40fa8e83b
3 changed files with 21 additions and 8 deletions

View File

@ -139,11 +139,11 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = {
MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"),
MLSDdetector: Path("mlsd_large_512_fp32.pth"),
PidiNetDetector: Path("table5_pidinet.pth"),
ZoeDetector: Path("ZoeD_M12_N.pt"),
CONTROLNET_PROCESSORS: ClassVar[Dict[type, str]] = {
MidasDetector: "lllyasviel/Annotators::/dpt_hybrid-midas-501f0c75.pt",
MLSDdetector: "lllyasviel/Annotators::/mlsd_large_512_fp32.pth",
PidiNetDetector: "lllyasviel/Annotators::/table5_pidinet.pth",
ZoeDetector: "lllyasviel/Annotators::/ZoeD_M12_N.pt",
}
def run_processor(self, image: Image.Image) -> Image.Image:
@ -155,7 +155,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return context.images.get_pil(self.image.image_name, "RGB")
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
remote_source = f"lllyasviel/Annotators::/{self.CONTROLNET_PROCESSORS[processor]}"
remote_source = self.CONTROLNET_PROCESSORS[processor]
assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors!
model = self._context.models.load_remote_model(
source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)

View File

@ -26,7 +26,7 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[/:]+", "_", value.lower())
value = re.sub(r"[^.\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")

View File

@ -81,8 +81,21 @@ def test_download_diffusers(mock_context: InvocationContext) -> None:
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::/vae")
assert model_path.is_dir()
assert model_path.name != "vae" # will not create the vae subfolder with preserve_subfolders False
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
model_path / "diffusion_pytorch_model.safetensors"
).exists()
def test_download_diffusers_preserve_subfolders(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model(
"stabilityai/sdxl-turbo::/vae",
preserve_subfolders=True,
)
assert model_path.is_dir()
assert model_path.name == "vae" # will create the vae subfolder with preserve_subfolders True
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
model_path / "diffusion_pytorch_model.safetensors"
).exists()