mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
simplify mapping of detectors to controlnet processors; tweaked download cache name algorithm
This commit is contained in:
parent
1e357bd21b
commit
a40fa8e83b
@ -139,11 +139,11 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
image: ImageField = InputField(description="The image to process")
|
image: ImageField = InputField(description="The image to process")
|
||||||
|
|
||||||
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
|
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
|
||||||
CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = {
|
CONTROLNET_PROCESSORS: ClassVar[Dict[type, str]] = {
|
||||||
MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"),
|
MidasDetector: "lllyasviel/Annotators::/dpt_hybrid-midas-501f0c75.pt",
|
||||||
MLSDdetector: Path("mlsd_large_512_fp32.pth"),
|
MLSDdetector: "lllyasviel/Annotators::/mlsd_large_512_fp32.pth",
|
||||||
PidiNetDetector: Path("table5_pidinet.pth"),
|
PidiNetDetector: "lllyasviel/Annotators::/table5_pidinet.pth",
|
||||||
ZoeDetector: Path("ZoeD_M12_N.pt"),
|
ZoeDetector: "lllyasviel/Annotators::/ZoeD_M12_N.pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
@ -155,7 +155,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
return context.images.get_pil(self.image.image_name, "RGB")
|
return context.images.get_pil(self.image.image_name, "RGB")
|
||||||
|
|
||||||
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
|
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
|
||||||
remote_source = f"lllyasviel/Annotators::/{self.CONTROLNET_PROCESSORS[processor]}"
|
remote_source = self.CONTROLNET_PROCESSORS[processor]
|
||||||
assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors!
|
assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors!
|
||||||
model = self._context.models.load_remote_model(
|
model = self._context.models.load_remote_model(
|
||||||
source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)
|
source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)
|
||||||
|
@ -26,7 +26,7 @@ def slugify(value: str, allow_unicode: bool = False) -> str:
|
|||||||
value = unicodedata.normalize("NFKC", value)
|
value = unicodedata.normalize("NFKC", value)
|
||||||
else:
|
else:
|
||||||
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
|
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
|
||||||
value = re.sub(r"[/]", "_", value.lower())
|
value = re.sub(r"[/:]+", "_", value.lower())
|
||||||
value = re.sub(r"[^.\w\s-]", "", value.lower())
|
value = re.sub(r"[^.\w\s-]", "", value.lower())
|
||||||
return re.sub(r"[-\s]+", "-", value).strip("-_")
|
return re.sub(r"[-\s]+", "-", value).strip("-_")
|
||||||
|
|
||||||
|
@ -81,8 +81,21 @@ def test_download_diffusers(mock_context: InvocationContext) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
|
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
|
||||||
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
|
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::/vae")
|
||||||
assert model_path.is_dir()
|
assert model_path.is_dir()
|
||||||
|
assert model_path.name != "vae" # will not create the vae subfolder with preserve_subfolders False
|
||||||
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
|
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
|
||||||
model_path / "diffusion_pytorch_model.safetensors"
|
model_path / "diffusion_pytorch_model.safetensors"
|
||||||
).exists()
|
).exists()
|
||||||
|
|
||||||
|
def test_download_diffusers_preserve_subfolders(mock_context: InvocationContext) -> None:
|
||||||
|
model_path = mock_context.models.download_and_cache_model(
|
||||||
|
"stabilityai/sdxl-turbo::/vae",
|
||||||
|
preserve_subfolders=True,
|
||||||
|
)
|
||||||
|
assert model_path.is_dir()
|
||||||
|
assert model_path.name == "vae" # will create the vae subfolder with preserve_subfolders True
|
||||||
|
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
|
||||||
|
model_path / "diffusion_pytorch_model.safetensors"
|
||||||
|
).exists()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user