more refactoring; HF subfolders not working

Lincoln Stein 2024-05-16 22:26:18 -04:00
parent 911a24479b
commit 2dae5eb7ad
5 changed files with 113 additions and 102 deletions

View File

@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
 following initialization pattern:

 ```
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config import get_config
 from invokeai.app.services.model_records import ModelRecordServiceSQL
 from invokeai.app.services.model_install import ModelInstallService
 from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.util.logging import InvokeAILogger

-config = InvokeAIAppConfig.get_config()
-config.parse_args()
+config = get_config()
 logger = InvokeAILogger.get_logger(config=config)

-db = SqliteDatabase(config, logger)
+db = SqliteDatabase(config.db_path, logger)
 record_store = ModelRecordServiceSQL(db)
 queue = DownloadQueueService()
 queue.start()

 installer = ModelInstallService(app_config=config,
                                 record_store=record_store,
                                 download_queue=queue
                                 )
 installer.start()
 ```
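
For reference, a minimal usage sketch, assuming the running `installer` built in the pattern above; the repo id is illustrative, not a fixture from this commit:

```
# Assumes the `installer` created by the initialization pattern above.
# "stabilityai/sdxl-turbo" is an illustrative source string.
job = installer.heuristic_import("stabilityai/sdxl-turbo")
installer.wait_for_job(job, timeout=600)
if job.complete:
    print(f"Installed model with key {job.config_out.key}")
else:
    print(f"Install failed: {job.error_type}: {job.error}")
```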

View File

@@ -466,17 +466,14 @@ class ModelInstallServiceBase(ABC):
         """

     @abstractmethod
-    def download_and_cache_ckpt(
+    def download_and_cache_model(
         self,
-        source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: int = 0,
+        source: str,
     ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.

-        :param source: A Url or a string that can be converted into one.
-        :param access_token: Optional access token to access restricted resources.
+        :param source: A string representing a URL or repo_id.

         The model file will be downloaded into the system-wide model cache
         (`models/.cache`) if it isn't already there. Note that the model cache

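A short sketch of the simplified call, assuming a running installer as above; the URL is illustrative, and the returned path points into the `models/.cache` directory described in the docstring:

```
# Illustrative one-shot caching call; the URL is an assumption, not a fixture.
cached_path = installer.download_and_cache_model(
    "https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/sd_xl_turbo_1.0_fp16.safetensors"
)
print(f"Cached at: {cached_path}")
```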
View File

@@ -9,7 +9,7 @@ from pathlib import Path
 from queue import Empty, Queue
 from shutil import copyfile, copytree, move, rmtree
 from tempfile import mkdtemp
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import torch
 import yaml
@@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
 from requests import Session

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
+from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@@ -208,26 +208,12 @@ class ModelInstallService(ModelInstallServiceBase):
         access_token: Optional[str] = None,
         inplace: Optional[bool] = False,
     ) -> ModelInstallJob:
-        variants = "|".join(ModelRepoVariant.__members__.values())
-        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
-        source_obj: Optional[StringLikeSource] = None
-
-        if Path(source).exists():  # A local file or directory
-            source_obj = LocalModelSource(path=Path(source), inplace=inplace)
-        elif match := re.match(hf_repoid_re, source):
-            source_obj = HFModelSource(
-                repo_id=match.group(1),
-                variant=match.group(2) if match.group(2) else None,  # pass None rather than ''
-                subfolder=Path(match.group(3)) if match.group(3) else None,
-                access_token=access_token,
-            )
-        elif re.match(r"^https?://[^/]+", source):
-            source_obj = URLModelSource(
-                url=AnyHttpUrl(source),
-                access_token=access_token,
-            )
-        else:
-            raise ValueError(f"Unsupported model source: '{source}'")
+        """Install a model using pattern matching to infer the type of source."""
+        source_obj = self._guess_source(source)
+        if isinstance(source_obj, LocalModelSource):
+            source_obj.inplace = inplace
+        elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
+            source_obj.access_token = access_token
         return self.import_model(source_obj, config)

     def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob:  # noqa D102
@@ -383,37 +369,86 @@ class ModelInstallService(ModelInstallServiceBase):
         escaped_source = slugify(str(source))
         return app_config.download_cache_path / escaped_source

-    def download_and_cache_ckpt(
+    def download_and_cache_model(
         self,
-        source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: int = 0,
+        source: str,
     ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
-        model_path = self._download_cache_path(source, self._app_config)
+        model_path = self._download_cache_path(str(source), self._app_config)

-        # We expect the cache directory to contain one and only one downloaded file.
+        # We expect the cache directory to contain one and only one downloaded file or directory.
         # We don't know the file's name in advance, as it is set by the download
         # content-disposition header.
         if model_path.exists():
-            contents = [x for x in model_path.iterdir() if x.is_file()]
+            contents: List[Path] = list(model_path.iterdir())
             if len(contents) > 0:
                 return contents[0]

         model_path.mkdir(parents=True, exist_ok=True)
-        job = self._download_queue.download(
-            source=AnyHttpUrl(str(source)),
+        model_source = self._guess_source(source)
+        remote_files, _ = self._remote_files_from_source(model_source)
+        job = self._download_queue.multifile_download(
+            parts=remote_files,
             dest=model_path,
-            access_token=access_token,
-            on_progress=TqdmProgress().update,
         )
-        self._download_queue.wait_for_job(job, timeout)
+        files_string = "file" if len(remote_files) == 1 else "files"
+        self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+        self._download_queue.wait_for_job(job)
         if job.complete:
             assert job.download_path is not None
             return job.download_path
         else:
             raise Exception(job.error)

+    def _remote_files_from_source(
+        self, source: ModelSource
+    ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
+        metadata = None
+        if isinstance(source, HFModelSource):
+            metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
+            assert isinstance(metadata, ModelMetadataWithFiles)
+            return metadata.download_urls(
+                variant=source.variant or self._guess_variant(),
+                subfolder=source.subfolder,
+                session=self._session,
+            ), metadata
+
+        if isinstance(source, URLModelSource):
+            try:
+                fetcher = self.get_fetcher_from_url(str(source.url))
+                kwargs: dict[str, Any] = {"session": self._session}
+                metadata = fetcher(**kwargs).from_url(source.url)
+                assert isinstance(metadata, ModelMetadataWithFiles)
+                return metadata.download_urls(session=self._session), metadata
+            except ValueError:
+                pass
+            return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
+
+        raise Exception(f"No files associated with {source}")
+
+    def _guess_source(self, source: str) -> ModelSource:
+        """Turn a source string into a ModelSource object."""
+        variants = "|".join(ModelRepoVariant.__members__.values())
+        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
+        source_obj: Optional[StringLikeSource] = None
+
+        if Path(source).exists():  # A local file or directory
+            source_obj = LocalModelSource(path=Path(source))
+        elif match := re.match(hf_repoid_re, source):
+            source_obj = HFModelSource(
+                repo_id=match.group(1),
+                variant=match.group(2) if match.group(2) else None,  # pass None rather than ''
+                subfolder=Path(match.group(3)) if match.group(3) else None,
+            )
+        elif re.match(r"^https?://[^/]+", source):
+            source_obj = URLModelSource(
+                url=AnyHttpUrl(source),
+            )
+        else:
+            raise ValueError(f"Unsupported model source: '{source}'")
+        return source_obj
+
     # --------------------------------------------------------------------------------------------
     # Internal functions that manage the installer threads
     # --------------------------------------------------------------------------------------------
@@ -650,18 +685,9 @@ class ModelInstallService(ModelInstallServiceBase):
         config: Optional[Dict[str, Any]] = None,
     ) -> ModelInstallJob:
         # Add user's cached access token to HuggingFace requests
-        source.access_token = source.access_token or HfFolder.get_token()
-        if not source.access_token:
-            self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
-
-        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
-        assert isinstance(metadata, ModelMetadataWithFiles)
-        remote_files = metadata.download_urls(
-            variant=source.variant or self._guess_variant(),
-            subfolder=source.subfolder,
-            session=self._session,
-        )
-
+        if source.access_token is None:
+            source.access_token = HfFolder.get_token()
+        remote_files, metadata = self._remote_files_from_source(source)
         return self._import_remote_model(
             source=source,
             config=config,
@@ -674,21 +700,7 @@ class ModelInstallService(ModelInstallServiceBase):
         source: URLModelSource,
         config: Optional[Dict[str, Any]],
     ) -> ModelInstallJob:
-        # URLs from HuggingFace will be handled specially
-        metadata = None
-        fetcher = None
-        try:
-            fetcher = self.get_fetcher_from_url(str(source.url))
-        except ValueError:
-            pass
-        kwargs: dict[str, Any] = {"session": self._session}
-        if fetcher is not None:
-            metadata = fetcher(**kwargs).from_url(source.url)
-        self._logger.debug(f"metadata={metadata}")
-        if metadata and isinstance(metadata, ModelMetadataWithFiles):
-            remote_files = metadata.download_urls(session=self._session)
-        else:
-            remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
+        remote_files, metadata = self._remote_files_from_source(source)
         return self._import_remote_model(
             source=source,
             config=config,
@@ -733,26 +745,17 @@ class ModelInstallService(ModelInstallServiceBase):
             root = Path(".")
             subfolder = Path(".")

-        # we remember the path up to the top of the destdir so that it may be
-        # removed safely at the end of the install process.
-        install_job._install_tmpdir = destdir
-
         parts: List[RemoteModelFile] = []
         for model_file in remote_files:
             assert install_job.total_bytes is not None
             assert model_file.size is not None
             install_job.total_bytes += model_file.size
             parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
-        multifile_job = self._download_queue.multifile_download(
+        multifile_job = self._multifile_download(
             parts=parts,
             dest=destdir,
             access_token=source.access_token,
-            submit_job=False,
-            on_start=self._download_started_callback,
-            on_progress=self._download_progress_callback,
-            on_complete=self._download_complete_callback,
-            on_error=self._download_error_callback,
-            on_cancelled=self._download_cancelled_callback,
+            submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
         )
         self._download_cache[multifile_job.id] = install_job
         install_job._download_job = multifile_job
@@ -772,6 +775,21 @@ class ModelInstallService(ModelInstallServiceBase):
             size += sum(self._stat_size(Path(root, x)) for x in files)
         return size

+    def _multifile_download(
+        self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True
+    ) -> MultiFileDownloadJob:
+        return self._download_queue.multifile_download(
+            parts=parts,
+            dest=dest,
+            access_token=access_token,
+            submit_job=submit_job,
+            on_start=self._download_started_callback,
+            on_progress=self._download_progress_callback,
+            on_complete=self._download_complete_callback,
+            on_error=self._download_error_callback,
+            on_cancelled=self._download_cancelled_callback,
+        )
+
     # ------------------------------------------------------------------
     # Callbacks are executed by the download queue in a separate thread
     # ------------------------------------------------------------------
@@ -875,10 +893,9 @@ class ModelInstallService(ModelInstallServiceBase):
             assert job.local_path is not None
             assert job.config_out is not None
             key = job.config_out.key
-            self._event_bus.emit_model_install_completed(source=str(job.source),
-                                                         key=key,
-                                                         id=job.id,
-                                                         total_bytes=job.bytes)
+            self._event_bus.emit_model_install_completed(
+                source=str(job.source), key=key, id=job.id, total_bytes=job.bytes
+            )

     def _signal_job_errored(self, job: ModelInstallJob) -> None:
         self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
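
The commit title notes that HF subfolders are still not working; the `_guess_source` regex above is where the `repo_id[:variant][:subfolder]` syntax is parsed. A hedged illustration of the intended mapping, assuming the `hf_repoid_re` pattern as written (the inputs are examples, not code from this commit):

```
# Intended classification per the hf_repoid_re regex above (illustrative inputs):
installer._guess_source("/tmp/model.safetensors")       # LocalModelSource, if the path exists
installer._guess_source("stabilityai/sdxl-turbo")       # HFModelSource, variant=None, subfolder=None
installer._guess_source("stabilityai/sdxl-turbo:fp16")  # HFModelSource, variant="fp16"
installer._guess_source("stabilityai/sdxl-turbo::vae")  # HFModelSource, subfolder=Path("vae")
installer._guess_source("https://example.com/m.ckpt")   # URLModelSource
```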

View File

@@ -12,15 +12,13 @@ from pydantic import BaseModel, Field

 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
-from invokeai.backend.model_manager import (
+from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
-    ModelFormat,
-    ModelType,
-)
-from invokeai.backend.model_manager.config import (
     ControlAdapterDefaultSettings,
     MainModelDefaultSettings,
+    ModelFormat,
+    ModelType,
     ModelVariantType,
     SchedulerPredictionType,
 )

View File

@@ -222,7 +222,7 @@ def test_delete_register(
         store.get_model(key)


-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
@@ -253,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
     ]


-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
@@ -285,9 +285,8 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
     }


-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
-    # TODO: Test subfolder download
     source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
     bus = mm2_installer.event_bus
@@ -323,6 +322,7 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
     assert job.total_bytes == completed_events[0].payload["total_bytes"]
     assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])

+
 def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
     job = mm2_installer.import_model(source)
@@ -371,7 +371,7 @@ def test_other_error_during_install(
         },
     ],
 )
-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
     """Test whether or not type is respected on configs when passed to heuristic import."""
     assert "name" in model_params and "type" in model_params
@@ -387,7 +387,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
     }
     assert "repo_id" in model_params
    install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
-    mm2_installer.wait_for_job(install_job1, timeout=20)
+    mm2_installer.wait_for_job(install_job1, timeout=10)
     if model_params["type"] != "embedding":
         assert install_job1.errored
         assert install_job1.error_type == "InvalidModelConfigException"
@@ -396,6 +396,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
         assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out

     install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
-    mm2_installer.wait_for_job(install_job2, timeout=20)
+    mm2_installer.wait_for_job(install_job2, timeout=10)
     assert install_job2.complete
     assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
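
For context, a sketch of the call shape these tests exercise, assuming the same `mm2_installer` fixture and a `repo_id` from the test parametrization above (the config values mirror that parametrization and are not new fixtures):

```
# Heuristic-import flow exercised by test_heuristic_import_with_type (sketch).
config = {"name": "test_embedding", "type": "embedding"}
job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config)
mm2_installer.wait_for_job(job, timeout=10)
assert job.complete
assert job.config_out is not None  # the registered model's config record
```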