simplify mapping of detectors to controlnet processors; tweaked download cache name algorithm

This commit is contained in:
Lincoln Stein
2024-07-20 18:34:57 -04:00
parent 1e357bd21b
commit a40fa8e83b
3 changed files with 21 additions and 8 deletions

View File

@ -139,11 +139,11 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = {
MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"),
MLSDdetector: Path("mlsd_large_512_fp32.pth"),
PidiNetDetector: Path("table5_pidinet.pth"),
ZoeDetector: Path("ZoeD_M12_N.pt"),
CONTROLNET_PROCESSORS: ClassVar[Dict[type, str]] = {
MidasDetector: "lllyasviel/Annotators::/dpt_hybrid-midas-501f0c75.pt",
MLSDdetector: "lllyasviel/Annotators::/mlsd_large_512_fp32.pth",
PidiNetDetector: "lllyasviel/Annotators::/table5_pidinet.pth",
ZoeDetector: "lllyasviel/Annotators::/ZoeD_M12_N.pt",
}
def run_processor(self, image: Image.Image) -> Image.Image:
@ -155,7 +155,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return context.images.get_pil(self.image.image_name, "RGB")
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
remote_source = f"lllyasviel/Annotators::/{self.CONTROLNET_PROCESSORS[processor]}"
remote_source = self.CONTROLNET_PROCESSORS[processor]
assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors!
model = self._context.models.load_remote_model(
source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)

View File

@ -26,7 +26,7 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[/:]+", "_", value.lower())
value = re.sub(r"[^.\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")