add support for repo_id subfolders

This commit is contained in:
Lincoln Stein
2023-10-08 12:45:06 -04:00
parent 51060543dc
commit a64a34b49a
6 changed files with 83 additions and 33 deletions

View File

@ -123,11 +123,20 @@ installation. Examples:
# (list all controlnet models) # (list all controlnet models)
invokeai-model-install --list controlnet invokeai-model-install --list controlnet
# (install the model at the indicated URL) # (install the diffusers model using its hugging face repo_id)
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0
# (install a diffusers model that lives in a subfolder)
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0:vae
# (install the checkpoint model at the indicated URL)
invokeai-model-install --add https://civitai.com/api/download/models/128713 invokeai-model-install --add https://civitai.com/api/download/models/128713
# (delete the named model) # (delete the named model if its name is unique)
invokeai-model-install --delete sd-1/main/analog-diffusion invokeai-model-install --delete analog-diffusion
# (delete the named model using its fully qualified name)
invokeai-model-install --delete sd-1/main/test_model
``` ```
### Installation via the Web GUI ### Installation via the Web GUI
@ -141,6 +150,24 @@ left-hand panel) and navigate to *Import Models*
wish to install. You may use a URL, HuggingFace repo id, or a path on wish to install. You may use a URL, HuggingFace repo id, or a path on
your local disk. your local disk.
There is special scanning for CivitAI URLs which lets
you cut-and-paste either the URL for a CivitAI model page
(e.g. https://civitai.com/models/12345), or the direct download link
for a model (e.g. https://civitai.com/api/download/models/12345).
If the desired model is a HuggingFace diffusers model that is located
in a subfolder of the repository (e.g. vae), then append the subfolder
to the end of the repo_id like this:
```
# a VAE model located in subfolder "vae"a
stabilityai/stable-diffusion-xl-base-1.0:vae
# version 2 of the model located in subfolder "v2"
monster-labs/control_v1p_sd15_qrcode_monster:v2
```
3. Alternatively, the *Scan for Models* button allows you to paste in 3. Alternatively, the *Scan for Models* button allows you to paste in
the path to a folder somewhere on your machine. It will be scanned for the path to a folder somewhere on your machine. It will be scanned for
importable models and prompt you to add the ones of your choice. importable models and prompt you to add the ones of your choice.

View File

@ -25,11 +25,12 @@ class UnifiedModelInfo(BaseModel):
base_model: Optional[BaseModelType] = None base_model: Optional[BaseModelType] = None
model_type: Optional[ModelType] = None model_type: Optional[ModelType] = None
source: Optional[str] = None source: Optional[str] = None
subfolder: Optional[str] = None
description: Optional[str] = None description: Optional[str] = None
recommended: bool = False recommended: bool = False
installed: bool = False installed: bool = False
default: bool = False default: bool = False
requires: Optional[List[str]] = Field(default_factory=list) requires: List[str] = Field(default_factory=list)
@dataclass @dataclass
@ -117,6 +118,7 @@ class InstallHelper(object):
description=self._initial_models[key].get("description"), description=self._initial_models[key].get("description"),
recommended=self._initial_models[key].get("recommended", False), recommended=self._initial_models[key].get("recommended", False),
default=self._initial_models[key].get("default", False), default=self._initial_models[key].get("default", False),
subfolder=self._initial_models[key].get("subfolder"),
requires=list(self._initial_models[key].get("requires", [])), requires=list(self._initial_models[key].get("requires", [])),
) )
self.all_models[key] = info self.all_models[key] = info
@ -154,10 +156,8 @@ class InstallHelper(object):
reverse_source = {x.source: x for x in self.all_models.values()} reverse_source = {x.source: x for x in self.all_models.values()}
additional_models = [] additional_models = []
for model_info in model_list: for model_info in model_list:
print(f"DEBUG: model_info={model_info}")
for requirement in model_info.requires: for requirement in model_info.requires:
if requirement not in installed: if requirement not in installed:
print(f"DEBUG: installing {requirement}")
additional_models.append(reverse_source.get(requirement)) additional_models.append(reverse_source.get(requirement))
model_list.extend(additional_models) model_list.extend(additional_models)
@ -168,6 +168,7 @@ class InstallHelper(object):
metadata = ModelSourceMetadata(description=model.description, name=model.name) metadata = ModelSourceMetadata(description=model.description, name=model.name)
installer.install( installer.install(
model.source, model.source,
subfolder=model.subfolder,
variant="fp16" if self._config.precision == "float16" else None, variant="fp16" if self._config.precision == "float16" else None,
access_token=ACCESS_TOKEN, # this is a global, access_token=ACCESS_TOKEN, # this is a global,
metadata=metadata, metadata=metadata,

View File

@ -46,7 +46,7 @@ CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Regular expressions to describe repo_ids and http urls # Regular expressions to describe repo_ids and http urls
HTTP_RE = r"^https?://" HTTP_RE = r"^https?://"
REPO_ID_RE = r"^[\w-]+/[.\w-]+$" REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^[\w-]+/[.\w-]+(?::\w+)?$" REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
class DownloadJobPath(DownloadJobBase): class DownloadJobPath(DownloadJobBase):
@ -73,6 +73,9 @@ class DownloadJobRepoID(DownloadJobRemoteSource):
"""Download repo ids.""" """Download repo ids."""
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)") source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
subfolder: Optional[str] = Field(
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
)
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download") variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
subqueue: Optional["DownloadQueueBase"] = Field( subqueue: Optional["DownloadQueueBase"] = Field(
description="a subqueue used for downloading the individual files in the repo_id", default=None description="a subqueue used for downloading the individual files in the repo_id", default=None
@ -572,7 +575,9 @@ class DownloadQueue(DownloadQueueBase):
variant = job.variant variant = job.variant
if not job.metadata: if not job.metadata:
job.metadata = ModelSourceMetadata() job.metadata = ModelSourceMetadata()
urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata) urls_to_download = self._get_repo_info(
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
)
if job.destination.name != Path(repo_id).name: if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name job.destination = job.destination / Path(repo_id).name
bytes_downloaded: Dict[int, int] = dict() bytes_downloaded: Dict[int, int] = dict()
@ -605,6 +610,7 @@ class DownloadQueue(DownloadQueueBase):
repo_id: str, repo_id: str,
metadata: ModelSourceMetadata, metadata: ModelSourceMetadata,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None,
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]: ) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
""" """
Given a repo_id and an optional variant, return list of URLs to download to get the model. Given a repo_id and an optional variant, return list of URLs to download to get the model.
@ -620,15 +626,26 @@ class DownloadQueue(DownloadQueueBase):
sibs = model_info.siblings sibs = model_info.siblings
paths = [x.rfilename for x in sibs] paths = [x.rfilename for x in sibs]
sizes = {x.rfilename: x.size for x in sibs} sizes = {x.rfilename: x.size for x in sibs}
if "model_index.json" in paths:
url = hf_hub_url(repo_id, filename="model_index.json") prefix = ""
if subfolder:
prefix = f"{subfolder}/"
paths = [x for x in paths if x.startswith(prefix)]
if f"{prefix}model_index.json" in paths:
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
resp = self._requests.get(url) resp = self._requests.get(url)
resp.raise_for_status() # will raise an HTTPError on non-200 status resp.raise_for_status() # will raise an HTTPError on non-200 status
submodels = resp.json() submodels = resp.json()
paths = [x for x in paths if Path(x).parent.as_posix() in submodels] paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, "model_index.json") paths.insert(0, f"{prefix}model_index.json")
urls = [ urls = [
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), Path(x.name), sizes[x.as_posix()]) (
hf_hub_url(repo_id, filename=x.as_posix()),
x.parent.relative_to(prefix) or Path("."),
Path(x.name),
sizes[x.as_posix()],
)
for x in self._select_variants(paths, variant) for x in self._select_variants(paths, variant)
] ]
if hasattr(model_info, "cardData"): if hasattr(model_info, "cardData"):

View File

@ -73,7 +73,14 @@ from .config import (
SubModelType, SubModelType,
) )
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
from .download.queue import HTTP_RE, REPO_ID_RE, DownloadJobPath, DownloadJobRepoID, DownloadJobURL from .download.queue import (
HTTP_RE,
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
DownloadJobRemoteSource,
DownloadJobPath,
DownloadJobRepoID,
DownloadJobURL,
)
from .hash import FastModelHash from .hash import FastModelHash
from .models import InvalidModelException from .models import InvalidModelException
from .probe import ModelProbe, ModelProbeInfo from .probe import ModelProbe, ModelProbeInfo
@ -81,7 +88,7 @@ from .search import ModelSearch
from .storage import DuplicateModelException, ModelConfigStore from .storage import DuplicateModelException, ModelConfigStore
class ModelInstallJob(DownloadJobBase): class ModelInstallJob(DownloadJobRemoteSource):
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info.""" """This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
model_key: Optional[str] = Field( model_key: Optional[str] = Field(
@ -185,6 +192,7 @@ class ModelInstallBase(ABC):
inplace: bool = True, inplace: bool = True,
priority: int = 10, priority: int = 10,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None, probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None, metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
@ -206,6 +214,8 @@ class ModelInstallBase(ABC):
the models directory, but registered in place (the default). the models directory, but registered in place (the default).
:param variant: For HuggingFace models, this optional parameter :param variant: For HuggingFace models, this optional parameter
specifies which variant to download (e.g. 'fp16') specifies which variant to download (e.g. 'fp16')
:param subfolder: When downloading HF repo_ids this can be used to
specify a subfolder of the HF repository to download from.
:param probe_override: Optional dict. Any fields in this dict :param probe_override: Optional dict. Any fields in this dict
will override corresponding probe fields. Use it to override will override corresponding probe fields. Use it to override
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`. `base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
@ -525,13 +535,16 @@ class ModelInstall(ModelInstallBase):
inplace: bool = True, inplace: bool = True,
priority: int = 10, priority: int = 10,
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None, probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None, metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> DownloadJobBase: # noqa D102 ) -> DownloadJobBase: # noqa D102
queue = self._download_queue queue = self._download_queue
job = self._make_download_job(source, variant=variant, access_token=access_token, priority=priority) job = self._make_download_job(
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
)
handler = ( handler = (
self._complete_registration_handler self._complete_registration_handler
if inplace and Path(source).exists() if inplace and Path(source).exists()
@ -624,6 +637,7 @@ class ModelInstall(ModelInstallBase):
self, self,
source: Union[str, Path, AnyHttpUrl], source: Union[str, Path, AnyHttpUrl],
variant: Optional[str] = None, variant: Optional[str] = None,
subfolder: Optional[str] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
priority: Optional[int] = 10, priority: Optional[int] = 10,
) -> ModelInstallJob: ) -> ModelInstallJob:
@ -643,9 +657,11 @@ class ModelInstall(ModelInstallBase):
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir) self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
cls = ModelInstallJob cls = ModelInstallJob
if re.match(REPO_ID_RE, str(source)): if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = ModelInstallRepoIDJob cls = ModelInstallRepoIDJob
kwargs = dict(variant=variant) source = match.group(1)
subfolder = match.group(2) or subfolder
kwargs = dict(variant=variant, subfolder=subfolder)
elif re.match(HTTP_RE, str(source)): elif re.match(HTTP_RE, str(source)):
cls = ModelInstallURLJob cls = ModelInstallURLJob
kwargs = {} kwargs = {}

View File

@ -60,9 +60,9 @@ sd-1/main/trinart_stable_diffusion_v2:
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
source: naclbit/trinart_stable_diffusion_v2 source: naclbit/trinart_stable_diffusion_v2
recommended: False recommended: False
#sd-1/controlnet/qrcode_monster: sd-1/controlnet/qrcode_monster:
# repo_id: monster-labs/control_v1p_sd15_qrcode_monster source: monster-labs/control_v1p_sd15_qrcode_monster
# subfolder: v2 subfolder: v2
sd-1/controlnet/canny: sd-1/controlnet/canny:
source: lllyasviel/control_v11p_sd15_canny source: lllyasviel/control_v11p_sd15_canny
recommended: True recommended: True

View File

@ -26,7 +26,7 @@ from pydantic import BaseModel
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper from invokeai.backend.install.install_helper import InstallHelper, UnifiedModelInfo
from invokeai.backend.model_manager import BaseModelType, ModelType from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util import choose_precision, choose_torch_device
@ -56,17 +56,6 @@ NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(
MAX_OTHER_MODELS = 72 MAX_OTHER_MODELS = 72
class UnifiedModelInfo(BaseModel):
name: Optional[str] = None
base_model: Optional[BaseModelType] = None
model_type: Optional[ModelType] = None
source: Optional[str] = None
description: Optional[str] = None
recommended: bool = False
installed: bool = False
default: bool = False
@dataclass @dataclass
class InstallSelections: class InstallSelections:
install_models: List[UnifiedModelInfo] = field(default_factory=list) install_models: List[UnifiedModelInfo] = field(default_factory=list)