mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for repo_id subfolders
This commit is contained in:
@ -123,11 +123,20 @@ installation. Examples:
|
|||||||
# (list all controlnet models)
|
# (list all controlnet models)
|
||||||
invokeai-model-install --list controlnet
|
invokeai-model-install --list controlnet
|
||||||
|
|
||||||
# (install the model at the indicated URL)
|
# (install the diffusers model using its hugging face repo_id)
|
||||||
|
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0
|
||||||
|
|
||||||
|
# (install a diffusers model that lives in a subfolder)
|
||||||
|
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0:vae
|
||||||
|
|
||||||
|
# (install the checkpoint model at the indicated URL)
|
||||||
invokeai-model-install --add https://civitai.com/api/download/models/128713
|
invokeai-model-install --add https://civitai.com/api/download/models/128713
|
||||||
|
|
||||||
# (delete the named model)
|
# (delete the named model if its name is unique)
|
||||||
invokeai-model-install --delete sd-1/main/analog-diffusion
|
invokeai-model-install --delete analog-diffusion
|
||||||
|
|
||||||
|
# (delete the named model using its fully qualified name)
|
||||||
|
invokeai-model-install --delete sd-1/main/test_model
|
||||||
```
|
```
|
||||||
|
|
||||||
### Installation via the Web GUI
|
### Installation via the Web GUI
|
||||||
@ -141,6 +150,24 @@ left-hand panel) and navigate to *Import Models*
|
|||||||
wish to install. You may use a URL, HuggingFace repo id, or a path on
|
wish to install. You may use a URL, HuggingFace repo id, or a path on
|
||||||
your local disk.
|
your local disk.
|
||||||
|
|
||||||
|
There is special scanning for CivitAI URLs which lets
|
||||||
|
you cut-and-paste either the URL for a CivitAI model page
|
||||||
|
(e.g. https://civitai.com/models/12345), or the direct download link
|
||||||
|
for a model (e.g. https://civitai.com/api/download/models/12345).
|
||||||
|
|
||||||
|
If the desired model is a HuggingFace diffusers model that is located
|
||||||
|
in a subfolder of the repository (e.g. vae), then append the subfolder
|
||||||
|
to the end of the repo_id like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
# a VAE model located in subfolder "vae"a
|
||||||
|
stabilityai/stable-diffusion-xl-base-1.0:vae
|
||||||
|
|
||||||
|
# version 2 of the model located in subfolder "v2"
|
||||||
|
monster-labs/control_v1p_sd15_qrcode_monster:v2
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
3. Alternatively, the *Scan for Models* button allows you to paste in
|
3. Alternatively, the *Scan for Models* button allows you to paste in
|
||||||
the path to a folder somewhere on your machine. It will be scanned for
|
the path to a folder somewhere on your machine. It will be scanned for
|
||||||
importable models and prompt you to add the ones of your choice.
|
importable models and prompt you to add the ones of your choice.
|
||||||
|
@ -25,11 +25,12 @@ class UnifiedModelInfo(BaseModel):
|
|||||||
base_model: Optional[BaseModelType] = None
|
base_model: Optional[BaseModelType] = None
|
||||||
model_type: Optional[ModelType] = None
|
model_type: Optional[ModelType] = None
|
||||||
source: Optional[str] = None
|
source: Optional[str] = None
|
||||||
|
subfolder: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
recommended: bool = False
|
recommended: bool = False
|
||||||
installed: bool = False
|
installed: bool = False
|
||||||
default: bool = False
|
default: bool = False
|
||||||
requires: Optional[List[str]] = Field(default_factory=list)
|
requires: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -117,6 +118,7 @@ class InstallHelper(object):
|
|||||||
description=self._initial_models[key].get("description"),
|
description=self._initial_models[key].get("description"),
|
||||||
recommended=self._initial_models[key].get("recommended", False),
|
recommended=self._initial_models[key].get("recommended", False),
|
||||||
default=self._initial_models[key].get("default", False),
|
default=self._initial_models[key].get("default", False),
|
||||||
|
subfolder=self._initial_models[key].get("subfolder"),
|
||||||
requires=list(self._initial_models[key].get("requires", [])),
|
requires=list(self._initial_models[key].get("requires", [])),
|
||||||
)
|
)
|
||||||
self.all_models[key] = info
|
self.all_models[key] = info
|
||||||
@ -154,10 +156,8 @@ class InstallHelper(object):
|
|||||||
reverse_source = {x.source: x for x in self.all_models.values()}
|
reverse_source = {x.source: x for x in self.all_models.values()}
|
||||||
additional_models = []
|
additional_models = []
|
||||||
for model_info in model_list:
|
for model_info in model_list:
|
||||||
print(f"DEBUG: model_info={model_info}")
|
|
||||||
for requirement in model_info.requires:
|
for requirement in model_info.requires:
|
||||||
if requirement not in installed:
|
if requirement not in installed:
|
||||||
print(f"DEBUG: installing {requirement}")
|
|
||||||
additional_models.append(reverse_source.get(requirement))
|
additional_models.append(reverse_source.get(requirement))
|
||||||
model_list.extend(additional_models)
|
model_list.extend(additional_models)
|
||||||
|
|
||||||
@ -168,6 +168,7 @@ class InstallHelper(object):
|
|||||||
metadata = ModelSourceMetadata(description=model.description, name=model.name)
|
metadata = ModelSourceMetadata(description=model.description, name=model.name)
|
||||||
installer.install(
|
installer.install(
|
||||||
model.source,
|
model.source,
|
||||||
|
subfolder=model.subfolder,
|
||||||
variant="fp16" if self._config.precision == "float16" else None,
|
variant="fp16" if self._config.precision == "float16" else None,
|
||||||
access_token=ACCESS_TOKEN, # this is a global,
|
access_token=ACCESS_TOKEN, # this is a global,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
@ -46,7 +46,7 @@ CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
|||||||
# Regular expressions to describe repo_ids and http urls
|
# Regular expressions to describe repo_ids and http urls
|
||||||
HTTP_RE = r"^https?://"
|
HTTP_RE = r"^https?://"
|
||||||
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
|
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
|
||||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^[\w-]+/[.\w-]+(?::\w+)?$"
|
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
|
||||||
|
|
||||||
|
|
||||||
class DownloadJobPath(DownloadJobBase):
|
class DownloadJobPath(DownloadJobBase):
|
||||||
@ -73,6 +73,9 @@ class DownloadJobRepoID(DownloadJobRemoteSource):
|
|||||||
"""Download repo ids."""
|
"""Download repo ids."""
|
||||||
|
|
||||||
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
|
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
|
||||||
|
subfolder: Optional[str] = Field(
|
||||||
|
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
|
||||||
|
)
|
||||||
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
|
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
|
||||||
subqueue: Optional["DownloadQueueBase"] = Field(
|
subqueue: Optional["DownloadQueueBase"] = Field(
|
||||||
description="a subqueue used for downloading the individual files in the repo_id", default=None
|
description="a subqueue used for downloading the individual files in the repo_id", default=None
|
||||||
@ -572,7 +575,9 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
variant = job.variant
|
variant = job.variant
|
||||||
if not job.metadata:
|
if not job.metadata:
|
||||||
job.metadata = ModelSourceMetadata()
|
job.metadata = ModelSourceMetadata()
|
||||||
urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata)
|
urls_to_download = self._get_repo_info(
|
||||||
|
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
|
||||||
|
)
|
||||||
if job.destination.name != Path(repo_id).name:
|
if job.destination.name != Path(repo_id).name:
|
||||||
job.destination = job.destination / Path(repo_id).name
|
job.destination = job.destination / Path(repo_id).name
|
||||||
bytes_downloaded: Dict[int, int] = dict()
|
bytes_downloaded: Dict[int, int] = dict()
|
||||||
@ -605,6 +610,7 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
repo_id: str,
|
repo_id: str,
|
||||||
metadata: ModelSourceMetadata,
|
metadata: ModelSourceMetadata,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
|
subfolder: Optional[str] = None,
|
||||||
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
|
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
|
||||||
"""
|
"""
|
||||||
Given a repo_id and an optional variant, return list of URLs to download to get the model.
|
Given a repo_id and an optional variant, return list of URLs to download to get the model.
|
||||||
@ -620,15 +626,26 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
sibs = model_info.siblings
|
sibs = model_info.siblings
|
||||||
paths = [x.rfilename for x in sibs]
|
paths = [x.rfilename for x in sibs]
|
||||||
sizes = {x.rfilename: x.size for x in sibs}
|
sizes = {x.rfilename: x.size for x in sibs}
|
||||||
if "model_index.json" in paths:
|
|
||||||
url = hf_hub_url(repo_id, filename="model_index.json")
|
prefix = ""
|
||||||
|
if subfolder:
|
||||||
|
prefix = f"{subfolder}/"
|
||||||
|
paths = [x for x in paths if x.startswith(prefix)]
|
||||||
|
|
||||||
|
if f"{prefix}model_index.json" in paths:
|
||||||
|
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
|
||||||
resp = self._requests.get(url)
|
resp = self._requests.get(url)
|
||||||
resp.raise_for_status() # will raise an HTTPError on non-200 status
|
resp.raise_for_status() # will raise an HTTPError on non-200 status
|
||||||
submodels = resp.json()
|
submodels = resp.json()
|
||||||
paths = [x for x in paths if Path(x).parent.as_posix() in submodels]
|
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
|
||||||
paths.insert(0, "model_index.json")
|
paths.insert(0, f"{prefix}model_index.json")
|
||||||
urls = [
|
urls = [
|
||||||
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), Path(x.name), sizes[x.as_posix()])
|
(
|
||||||
|
hf_hub_url(repo_id, filename=x.as_posix()),
|
||||||
|
x.parent.relative_to(prefix) or Path("."),
|
||||||
|
Path(x.name),
|
||||||
|
sizes[x.as_posix()],
|
||||||
|
)
|
||||||
for x in self._select_variants(paths, variant)
|
for x in self._select_variants(paths, variant)
|
||||||
]
|
]
|
||||||
if hasattr(model_info, "cardData"):
|
if hasattr(model_info, "cardData"):
|
||||||
|
@ -73,7 +73,14 @@ from .config import (
|
|||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
|
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
|
||||||
from .download.queue import HTTP_RE, REPO_ID_RE, DownloadJobPath, DownloadJobRepoID, DownloadJobURL
|
from .download.queue import (
|
||||||
|
HTTP_RE,
|
||||||
|
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
|
||||||
|
DownloadJobRemoteSource,
|
||||||
|
DownloadJobPath,
|
||||||
|
DownloadJobRepoID,
|
||||||
|
DownloadJobURL,
|
||||||
|
)
|
||||||
from .hash import FastModelHash
|
from .hash import FastModelHash
|
||||||
from .models import InvalidModelException
|
from .models import InvalidModelException
|
||||||
from .probe import ModelProbe, ModelProbeInfo
|
from .probe import ModelProbe, ModelProbeInfo
|
||||||
@ -81,7 +88,7 @@ from .search import ModelSearch
|
|||||||
from .storage import DuplicateModelException, ModelConfigStore
|
from .storage import DuplicateModelException, ModelConfigStore
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallJob(DownloadJobBase):
|
class ModelInstallJob(DownloadJobRemoteSource):
|
||||||
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
||||||
|
|
||||||
model_key: Optional[str] = Field(
|
model_key: Optional[str] = Field(
|
||||||
@ -185,6 +192,7 @@ class ModelInstallBase(ABC):
|
|||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
priority: int = 10,
|
priority: int = 10,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
|
subfolder: Optional[str] = None,
|
||||||
probe_override: Optional[Dict[str, Any]] = None,
|
probe_override: Optional[Dict[str, Any]] = None,
|
||||||
metadata: Optional[ModelSourceMetadata] = None,
|
metadata: Optional[ModelSourceMetadata] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
@ -206,6 +214,8 @@ class ModelInstallBase(ABC):
|
|||||||
the models directory, but registered in place (the default).
|
the models directory, but registered in place (the default).
|
||||||
:param variant: For HuggingFace models, this optional parameter
|
:param variant: For HuggingFace models, this optional parameter
|
||||||
specifies which variant to download (e.g. 'fp16')
|
specifies which variant to download (e.g. 'fp16')
|
||||||
|
:param subfolder: When downloading HF repo_ids this can be used to
|
||||||
|
specify a subfolder of the HF repository to download from.
|
||||||
:param probe_override: Optional dict. Any fields in this dict
|
:param probe_override: Optional dict. Any fields in this dict
|
||||||
will override corresponding probe fields. Use it to override
|
will override corresponding probe fields. Use it to override
|
||||||
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
|
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
|
||||||
@ -525,13 +535,16 @@ class ModelInstall(ModelInstallBase):
|
|||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
priority: int = 10,
|
priority: int = 10,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
|
subfolder: Optional[str] = None,
|
||||||
probe_override: Optional[Dict[str, Any]] = None,
|
probe_override: Optional[Dict[str, Any]] = None,
|
||||||
metadata: Optional[ModelSourceMetadata] = None,
|
metadata: Optional[ModelSourceMetadata] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> DownloadJobBase: # noqa D102
|
) -> DownloadJobBase: # noqa D102
|
||||||
queue = self._download_queue
|
queue = self._download_queue
|
||||||
|
|
||||||
job = self._make_download_job(source, variant=variant, access_token=access_token, priority=priority)
|
job = self._make_download_job(
|
||||||
|
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
|
||||||
|
)
|
||||||
handler = (
|
handler = (
|
||||||
self._complete_registration_handler
|
self._complete_registration_handler
|
||||||
if inplace and Path(source).exists()
|
if inplace and Path(source).exists()
|
||||||
@ -624,6 +637,7 @@ class ModelInstall(ModelInstallBase):
|
|||||||
self,
|
self,
|
||||||
source: Union[str, Path, AnyHttpUrl],
|
source: Union[str, Path, AnyHttpUrl],
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
|
subfolder: Optional[str] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
priority: Optional[int] = 10,
|
priority: Optional[int] = 10,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
@ -643,9 +657,11 @@ class ModelInstall(ModelInstallBase):
|
|||||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||||
|
|
||||||
cls = ModelInstallJob
|
cls = ModelInstallJob
|
||||||
if re.match(REPO_ID_RE, str(source)):
|
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||||
cls = ModelInstallRepoIDJob
|
cls = ModelInstallRepoIDJob
|
||||||
kwargs = dict(variant=variant)
|
source = match.group(1)
|
||||||
|
subfolder = match.group(2) or subfolder
|
||||||
|
kwargs = dict(variant=variant, subfolder=subfolder)
|
||||||
elif re.match(HTTP_RE, str(source)):
|
elif re.match(HTTP_RE, str(source)):
|
||||||
cls = ModelInstallURLJob
|
cls = ModelInstallURLJob
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
@ -60,9 +60,9 @@ sd-1/main/trinart_stable_diffusion_v2:
|
|||||||
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
|
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
|
||||||
source: naclbit/trinart_stable_diffusion_v2
|
source: naclbit/trinart_stable_diffusion_v2
|
||||||
recommended: False
|
recommended: False
|
||||||
#sd-1/controlnet/qrcode_monster:
|
sd-1/controlnet/qrcode_monster:
|
||||||
# repo_id: monster-labs/control_v1p_sd15_qrcode_monster
|
source: monster-labs/control_v1p_sd15_qrcode_monster
|
||||||
# subfolder: v2
|
subfolder: v2
|
||||||
sd-1/controlnet/canny:
|
sd-1/controlnet/canny:
|
||||||
source: lllyasviel/control_v11p_sd15_canny
|
source: lllyasviel/control_v11p_sd15_canny
|
||||||
recommended: True
|
recommended: True
|
||||||
|
@ -26,7 +26,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.install.install_helper import InstallHelper
|
from invokeai.backend.install.install_helper import InstallHelper, UnifiedModelInfo
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
|
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||||
@ -56,17 +56,6 @@ NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(
|
|||||||
MAX_OTHER_MODELS = 72
|
MAX_OTHER_MODELS = 72
|
||||||
|
|
||||||
|
|
||||||
class UnifiedModelInfo(BaseModel):
|
|
||||||
name: Optional[str] = None
|
|
||||||
base_model: Optional[BaseModelType] = None
|
|
||||||
model_type: Optional[ModelType] = None
|
|
||||||
source: Optional[str] = None
|
|
||||||
description: Optional[str] = None
|
|
||||||
recommended: bool = False
|
|
||||||
installed: bool = False
|
|
||||||
default: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InstallSelections:
|
class InstallSelections:
|
||||||
install_models: List[UnifiedModelInfo] = field(default_factory=list)
|
install_models: List[UnifiedModelInfo] = field(default_factory=list)
|
||||||
|
Reference in New Issue
Block a user