mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
enable downloading from subfolders for repo_ids
This commit is contained in:
parent 676ccd8ebb
commit 034af2d9f8
@@ -2,6 +2,7 @@
 Utility (backend) functions used by model_install.py
 """
 import os
+import re
 import shutil
 import warnings
 from dataclasses import dataclass, field
@@ -88,6 +89,7 @@ class ModelLoadInfo:
     base_type: BaseModelType
     path: Optional[Path] = None
     repo_id: Optional[str] = None
+    subfolder: Optional[str] = None
     description: str = ""
     installed: bool = False
     recommended: bool = False
@@ -126,7 +128,10 @@ class ModelInstall(object):
             value["name"] = name
             value["base_type"] = base
             value["model_type"] = model_type
-            model_dict[key] = ModelLoadInfo(**value)
+            model_info = ModelLoadInfo(**value)
+            if model_info.subfolder and model_info.repo_id:
+                model_info.repo_id += f":{model_info.subfolder}"
+            model_dict[key] = model_info
 
         # supplement with entries in models.yaml
         installed_models = [x for x in self.mgr.list_models()]
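
A minimal sketch (not part of the diff) of what the new branch above produces: when a YAML entry carries both a repo_id and a subfolder, the subfolder is folded into the repo_id string with a colon separator.

# Plain strings stand in for the ModelLoadInfo attributes.
repo_id = "monster-labs/control_v1p_sd15_qrcode_monster"
subfolder = "v2"

if subfolder and repo_id:
    repo_id += f":{subfolder}"

print(repo_id)  # -> monster-labs/control_v1p_sd15_qrcode_monster:v2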
@@ -317,46 +322,64 @@ class ModelInstall(object):
         return self._install_path(Path(models_path), info)
 
     def _install_repo(self, repo_id: str) -> AddModelResult:
+        # hack to recover models stored in subfolders --
+        # Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
+        subfolder = None
+        if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
+            repo_id = match.group(1)
+            subfolder = match.group(2)
+
         hinfo = HfApi().model_info(repo_id)
 
         # we try to figure out how to download this most economically
         # list all the files in the repo
         files = [x.rfilename for x in hinfo.siblings]
+        if subfolder:
+            files = [x for x in files if x.startswith("v2/")]
+            print(f"DEBUG: files={files}")
+        prefix = f"{subfolder}/" if subfolder else ""
+
         location = None
 
         with TemporaryDirectory(dir=self.config.models_path) as staging:
             staging = Path(staging)
-            if "model_index.json" in files:
-                location = self._download_hf_pipeline(repo_id, staging)  # pipeline
-            elif "unet/model.onnx" in files:
+            if f"{prefix}model_index.json" in files:
+                location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder)  # pipeline
+            elif f"{prefix}unet/model.onnx" in files:
                 location = self._download_hf_model(repo_id, files, staging)
             else:
                 for suffix in ["safetensors", "bin"]:
-                    if f"pytorch_lora_weights.{suffix}" in files:
-                        location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging)  # LoRA
+                    if f"{prefix}pytorch_lora_weights.{suffix}" in files:
+                        location = self._download_hf_model(
+                            repo_id, [f"pytorch_lora_weights.bin"], staging, subfolder=subfolder
+                        )  # LoRA
                         break
                     elif (
-                        self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
+                        self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
                     ):  # vae, controlnet or some other standalone
-                        files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
-                        location = self._download_hf_model(repo_id, files, staging)
+                        files = [f"config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
+                        location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
                         break
-                    elif f"diffusion_pytorch_model.{suffix}" in files:
-                        files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
-                        location = self._download_hf_model(repo_id, files, staging)
+                    elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
+                        files = [f"config.json", f"diffusion_pytorch_model.{suffix}"]
+                        location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
                         break
-                    elif f"learned_embeds.{suffix}" in files:
-                        location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
+                    elif f"{prefix}learned_embeds.{suffix}" in files:
+                        location = self._download_hf_model(
+                            repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
+                        )
                         break
-                    elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files:  # IP-Adapter
-                        files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
-                        location = self._download_hf_model(repo_id, files, staging)
+                    elif (
+                        f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
+                    ):  # IP-Adapter
+                        files = [f"image_encoder.txt", f"ip_adapter.{suffix}"]
+                        location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
                         break
-                    elif f"model.{suffix}" in files and "config.json" in files:
+                    elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
                         # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
                         # by InvokeAI for use with IP-Adapters.
-                        files = ["config.json", f"model.{suffix}"]
-                        location = self._download_hf_model(repo_id, files, staging)
+                        files = [f"config.json", f"model.{suffix}"]
+                        location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
                         break
         if not location:
             logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
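
The counterpart decode step, shown here as a standalone sketch: the regex splits an "owner/name:subfolder" identifier back into the bare repo id and the subfolder name (the walrus operator requires Python 3.8+).

# Standalone sketch of the regex _install_repo uses to undo the encoding.
import re

repo_id = "monster-labs/control_v1p_sd15_qrcode_monster:v2"
subfolder = None
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
    repo_id = match.group(1)    # "monster-labs/control_v1p_sd15_qrcode_monster"
    subfolder = match.group(2)  # "v2"

print(repo_id, subfolder)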
@@ -443,15 +466,17 @@ class ModelInstall(object):
         else:
             return path
 
-    def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
+    def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
         """
-        This retrieves a StableDiffusion model from cache or remote and then
+        Retrieve a StableDiffusion model from cache or remote and then
         does a save_pretrained() to the indicated staging area.
         """
         _, name = repo_id.split("/")
         precision = torch_dtype(choose_torch_device())
         variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
 
+        print(f"DEBUG: subfolder = {subfolder}")
+
         model = None
         for variant in variants:
             try:
@@ -460,6 +485,7 @@ class ModelInstall(object):
                     variant=variant,
                     torch_dtype=precision,
                     safety_checker=None,
+                    subfolder=subfolder,
                 )
             except Exception as e:  # most errors are due to fp16 not being present. Fix this to catch other errors
                 if "fp16" not in str(e):
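
Outside the installer, the same idea can be exercised directly with diffusers. This is an illustration of what the subfolder means on the Hugging Face side, assuming the diffusers library is installed; it is not the pipeline class used in the diff.

# Load the v2 weights of the qrcode_monster repo from its "v2/" directory.
import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster",
    subfolder="v2",  # fetch config + weights from the repo's v2/ folder
    torch_dtype=torch.float16,
)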
@@ -474,7 +500,7 @@ class ModelInstall(object):
         model.save_pretrained(staging / name, safe_serialization=True)
         return staging / name
 
-    def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path:
+    def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
         _, name = repo_id.split("/")
         location = staging / name
         paths = list()
@@ -485,7 +511,7 @@ class ModelInstall(object):
                 model_dir=location / filePath.parent,
                 model_name=filePath.name,
                 access_token=self.access_token,
-                subfolder=filePath.parent,
+                subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
             )
             if p:
                 paths.append(p)
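
A note on the changed subfolder argument: since the conditional expression binds looser than the "/" operator, the line reads as (filePath.parent / subfolder) if subfolder else filePath.parent. A small pathlib sketch, with an illustrative filename:

# Pathlib sketch of the subfolder handling above.
from pathlib import Path

subfolder = "v2"
filePath = Path("diffusion_pytorch_model.safetensors")  # a file at the repo root

effective = filePath.parent / subfolder if subfolder else filePath.parent
print(effective)  # -> v2  (i.e. Path(".") / "v2")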
@@ -60,6 +60,9 @@ sd-1/main/trinart_stable_diffusion_v2:
   description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
   repo_id: naclbit/trinart_stable_diffusion_v2
   recommended: False
+sd-1/controlnet/qrcode_monster:
+  repo_id: monster-labs/control_v1p_sd15_qrcode_monster
+  subfolder: v2
 sd-1/controlnet/canny:
   repo_id: lllyasviel/control_v11p_sd15_canny
   recommended: True