diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1c188b300d..d32af4a513 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -495,10 +495,10 @@ class ModelInstallService(ModelInstallServiceBase): return id @staticmethod - def _guess_variant() -> ModelRepoVariant: + def _guess_variant() -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" precision = choose_precision(choose_torch_device()) - return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT + return ModelRepoVariant.FP16 if precision == "float16" else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: return ModelInstallJob( @@ -523,7 +523,7 @@ class ModelInstallService(ModelInstallServiceBase): if not source.access_token: self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id) + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) assert isinstance(metadata, ModelMetadataWithFiles) remote_files = metadata.download_urls( variant=source.variant or self._guess_variant(), diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 9f219132d4..57dfadcaea 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -30,6 +30,7 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager import ( BaseModelType, InvalidModelConfigException, + ModelRepoVariant, ModelType, ) from invokeai.backend.model_manager.metadata import UnknownMetadataException @@ -233,11 +234,18 @@ class InstallHelper(object): if model_path.exists(): # local file on disk return LocalModelSource(path=model_path.absolute(), inplace=True) - if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id + + # parsing huggingface repo ids + # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16" + variants = "|".join([x.lower() for x in ModelRepoVariant.__members__]) + if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): + repo_id = match.group(1) + repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None return HFModelSource( - repo_id=model_path_id_or_url, + repo_id=repo_id, access_token=HfFolder.get_token(), subfolder=model_info.subfolder, + variant=repo_variant, ) if re.match(r"^(http|https):", model_path_id_or_url): return URLModelSource(url=AnyHttpUrl(model_path_id_or_url)) @@ -278,9 +286,11 @@ class InstallHelper(object): model_name=model_name, ) if len(matches) > 1: - print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.") + print( + f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." + ) elif not matches: - print(f"{model}: unknown model") + print(f"{model_to_remove}: unknown model") else: for m in matches: print(f"Deleting {m.type}:{m.name}") diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 4488f8eafc..49ce6af2b8 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -109,7 +109,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - DEFAULT = "default" # model files without "fp16" or other qualifier + DEFAULT = "" # model files without "fp16" or other qualifier - empty str FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" @@ -246,6 +246,16 @@ class ONNXSD2Config(_MainConfig): upcast_attention: bool = True +class ONNXSDXLConfig(_MainConfig): + """Model config for ONNX format models based on sdxl.""" + + type: Literal[ModelType.ONNX] = ModelType.ONNX + format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + # No yaml config file for ONNX, so these are part of config + base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL + prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction + + class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" @@ -267,7 +277,7 @@ class T2IConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] -_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")] +_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")] _ControlNetConfig = Annotated[ Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format"), diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 19b0116ba3..e4c7077f78 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -16,7 +16,6 @@ from .model_cache.model_cache_default import ModelCache # This registers the subclasses that implement loaders of specific model types loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] for module in loaders: - print(f"module={module}") import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel"] diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7d4e8337c3..ee9d6d53e3 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -22,6 +22,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelTy from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.util.logging import InvokeAILogger @dataclass @@ -88,6 +89,7 @@ class AnyModelLoader: # this tracks the loader subclasses _registry: Dict[str, Type[ModelLoaderBase]] = {} + _logger: Logger = InvokeAILogger.get_logger() def __init__( self, @@ -167,7 +169,7 @@ class AnyModelLoader: """Define a decorator which registers the subclass of loader.""" def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - print("DEBUG: Registering class", subclass.__name__) + cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") key = cls._to_registry_key(base, type, format) cls._registry[key] = subclass return subclass diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index adc84d2051..757745072d 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -52,7 +52,7 @@ class ModelLoader(ModelLoaderBase): self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - self._torch_dtype = torch_dtype(choose_torch_device()) + self._torch_dtype = torch_dtype(choose_torch_device(), app_config) def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -102,8 +102,10 @@ class ModelLoader(ModelLoaderBase): self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None ) -> ModelLockerBase: # TO DO: This is not thread safe! - if self._ram_cache.exists(config.key, submodel_type): + try: return self._ram_cache.get(config.key, submodel_type) + except IndexError: + pass model_variant = getattr(config, "repo_variant", None) self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) @@ -119,7 +121,11 @@ class ModelLoader(ModelLoaderBase): size=calc_model_size_by_data(loaded_model), ) - return self._ram_cache.get(config.key, submodel_type) + return self._ram_cache.get( + key=config.key, + submodel_type=submodel_type, + stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]), + ) def get_size_fs( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None @@ -146,13 +152,21 @@ class ModelLoader(ModelLoaderBase): # TO DO: Add exception handling def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: if submodel_type: - config = self._load_diffusers_config(model_path, config_name="model_index.json") - module, class_name = config[submodel_type.value] - return self._hf_definition_to_type(module=module, class_name=class_name) + try: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + return self._hf_definition_to_type(module=module, class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException( + f'The "{submodel_type}" submodel is not available for this model.' + ) from e else: - config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config["_class_name"] - return self._hf_definition_to_type(module="diffusers", class_name=class_name) + try: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config["_class_name"] + return self._hf_definition_to_type(module="diffusers", class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e # This needs to be implemented in subclasses that handle checkpoints def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 6c87e2519e..0cb5184f3a 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -1,3 +1,4 @@ -"""Init file for RamCache.""" +"""Init file for ModelCache.""" + _all__ = ["ModelCacheBase", "ModelCache"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 14a7dfb4a1..b1a6768ee8 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -129,11 +129,17 @@ class ModelCacheBase(ABC, Generic[T]): self, key: str, submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, ) -> ModelLockerBase: """ - Retrieve model locker object using key and optional submodel_type. + Retrieve model using key and optional submodel_type. - This may return an UnknownModelException if the model is not in the cache. + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. """ pass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 688be8ceb4..7e30512a58 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -24,6 +24,7 @@ import math import sys import time from contextlib import suppress +from dataclasses import dataclass, field from logging import Logger from typing import Dict, List, Optional @@ -55,6 +56,20 @@ GIG = 1073741824 MB = 2**20 +@dataclass +class CacheStats(object): + """Collect statistics on cache performance.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + # {submodel_key => size} + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + class ModelCache(ModelCacheBase[AnyModel]): """Implementation of ModelCacheBase.""" @@ -94,6 +109,8 @@ class ModelCache(ModelCacheBase[AnyModel]): self._storage_device: torch.device = storage_device self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG + # used for stats collection + self.stats = CacheStats() self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] @@ -158,21 +175,40 @@ class ModelCache(ModelCacheBase[AnyModel]): self, key: str, submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, ) -> ModelLockerBase: """ Retrieve model using key and optional submodel_type. - This may return an IndexError if the model is not in the cache. + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. """ key = self._make_cache_key(key, submodel_type) - if key not in self._cached_models: + if key in self._cached_models: + self.stats.hits += 1 + else: + self.stats.misses += 1 raise IndexError(f"The model with key {key} is not in the cache.") + cache_entry = self._cached_models[key] + + # more stats + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) + # this moves the entry to the top (right end) of the stack with suppress(Exception): self._cache_stack.remove(key) self._cache_stack.append(key) - cache_entry = self._cached_models[key] return ModelLocker( cache=self, cache_entry=cache_entry, diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py new file mode 100644 index 0000000000..935a6b7c95 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for Onnx model loading in InvokeAI.""" + +# This should work the same as Stable Diffusion pipelines +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) +class OnnyxDiffusersModel(ModelLoader): + """Class to load onnx models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not submodel_type is not None: + raise Exception("A submodel type must be provided when loading onnx pipelines.") + load_class = self._get_hf_load_class(model_path, submodel_type) + variant = model_variant.value if model_variant else None + model_path = model_path / submodel_type.value + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=self._torch_dtype, + variant=variant, + ) # type: ignore + return result diff --git a/invokeai/backend/model_manager/metadata/fetch/civitai.py b/invokeai/backend/model_manager/metadata/fetch/civitai.py index 6e41d6f11b..7991f6a748 100644 --- a/invokeai/backend/model_manager/metadata/fetch/civitai.py +++ b/invokeai/backend/model_manager/metadata/fetch/civitai.py @@ -32,6 +32,8 @@ import requests from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, CivitaiMetadata, @@ -82,10 +84,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): return self.from_civitai_versionid(int(version_id)) raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns") - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given a Civitai model version ID, return a ModelRepoMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum (currently ignored) + May raise an `UnknownMetadataException`. """ return self.from_civitai_versionid(int(id)) diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py index 58b65b6947..d628ab5c17 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -18,6 +18,8 @@ from typing import Optional from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator @@ -45,10 +47,13 @@ class ModelMetadataFetchBase(ABC): pass @abstractmethod - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given an ID for a model, return a ModelMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum. + This method will raise a `UnknownMetadataException` in the event that the requested model's metadata is not found at the provided id. """ diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 5d1eb0cc9e..6f04e8713b 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,10 +19,12 @@ from typing import Optional import requests from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, HuggingFaceMetadata, @@ -53,12 +55,22 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): metadata = HuggingFaceMetadata.model_validate_json(json) return metadata - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """Return a HuggingFaceMetadata object given the model's repo_id.""" - try: - model_info = HfApi().model_info(repo_id=id, files_metadata=True) - except RepositoryNotFoundError as excp: - raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp + # Little loop which tries fetching a revision corresponding to the selected variant. + # If not available, then set variant to None and get the default. + # If this too fails, raise exception. + model_info = None + while not model_info: + try: + model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant) + except RepositoryNotFoundError as excp: + raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp + except RevisionNotFoundError: + if variant is None: + raise + else: + variant = None _, name = id.split("/") return HuggingFaceMetadata( @@ -70,7 +82,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): tags=model_info.tags, files=[ RemoteModelFile( - url=hf_hub_url(id, x.rfilename), + url=hf_hub_url(id, x.rfilename, revision=variant), path=Path(name, x.rfilename), size=x.size, sha256=x.lfs.get("sha256") if x.lfs else None, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 5aa883d26d..5c3afcdc96 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -184,7 +184,6 @@ class HuggingFaceMetadata(ModelMetadataWithFiles): [x.path for x in self.files], variant, subfolder ) # all files in the model prefix = f"{subfolder}/" if subfolder else "" - # the next step reads model_index.json to determine which subdirectories belong # to the model if Path(f"{prefix}model_index.json") in paths: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 64a20a2092..55a9c0464a 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -7,6 +7,7 @@ import safetensors.torch import torch from picklescan.scanner import scan_file_path +import invokeai.backend.util.logging as logger from invokeai.backend.model_management.models.base import read_checkpoint_meta from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat from invokeai.backend.model_management.util import lora_token_vector_length @@ -590,13 +591,20 @@ class TextualInversionFolderProbe(FolderProbeBase): return TextualInversionCheckpointProbe(path).get_base_type() -class ONNXFolderProbe(FolderProbeBase): +class ONNXFolderProbe(PipelineFolderProbe): + def get_base_type(self) -> BaseModelType: + # Due to the way the installer is set up, the configuration file for safetensors + # will come along for the ride if both the onnx and safetensors forms + # share the same directory. We take advantage of this here. + if (self.model_path / "unet" / "config.json").exists(): + return super().get_base_type() + else: + logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"') + return BaseModelType.StableDiffusion1 + def get_format(self) -> ModelFormat: return ModelFormat("onnx") - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 6976059044..a894d915de 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -41,13 +41,21 @@ def filter_files( for file in files: if file.name.endswith((".json", ".txt")): paths.append(file) - elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")): + elif file.name.endswith( + ( + "learned_embeds.bin", + "ip_adapter.bin", + "lora_weights.safetensors", + "weights.pb", + "onnx_data", + ) + ): paths.append(file) # BRITTLENESS WARNING!! # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid # downloading random checkpoints that might also be in the repo. However there is no guarantee # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models - # will adhere to this naming convention, so this is an area of brittleness. + # will adhere to this naming convention, so this is an area to be careful of. elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): paths.append(file) @@ -64,7 +72,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path result = set() basenames: Dict[Path, Path] = {} for path in files: - if path.suffix == ".onnx": + if path.suffix in [".onnx", ".pb", ".onnx_data"]: if variant == ModelRepoVariant.ONNX: result.add(path) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index a787f9b6f4..b4f24d8483 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -29,12 +29,17 @@ def choose_torch_device() -> torch.device: return torch.device(config.device) -def choose_precision(device: torch.device) -> str: - """Returns an appropriate precision for the given torch device""" +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str: + """Return an appropriate precision for the given torch device.""" + app_config = app_config or config if device.type == "cuda": device_name = torch.cuda.get_device_name(device) if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): - if config.precision == "bfloat16": + if app_config.precision == "float32": + return "float32" + elif app_config.precision == "bfloat16": return "bfloat16" else: return "float16" @@ -43,9 +48,14 @@ def choose_precision(device: torch.device) -> str: return "float32" -def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype: +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def torch_dtype( + device: Optional[torch.device] = None, + app_config: Optional[InvokeAIAppConfig] = None, +) -> torch.dtype: device = device or choose_torch_device() - precision = choose_precision(device) + precision = choose_precision(device, app_config) if precision == "float16": return torch.float16 if precision == "bfloat16": diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install2.py index 6eb480c8d9..51a633a565 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install2.py @@ -505,7 +505,7 @@ def list_models(installer: ModelInstallService, model_type: ModelType): print(f"Installed models of type `{model_type}`:") for model in models: path = (config.models_path / model.path).resolve() - print(f"{model.name:40}{model.base.value:14}{path}") + print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") # --------------------------------------------------------