Multiple refinements on loaders:

- Enabled cache stat collection.
- Implemented ONNX loading.
- Added the ability to specify the repo version variant in the installer CLI.
- If the caller asks for a repo version that doesn't exist, the installer now falls
  back to the default (empty) version rather than raising an error.
Lincoln Stein 2024-02-05 21:55:11 -05:00 committed by Brandon Rising
parent fdbd288956
commit 92843d55eb
18 changed files with 215 additions and 49 deletions

View File

@@ -495,10 +495,10 @@ class ModelInstallService(ModelInstallServiceBase):
         return id

     @staticmethod
-    def _guess_variant() -> ModelRepoVariant:
+    def _guess_variant() -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
         precision = choose_precision(choose_torch_device())
-        return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT
+        return ModelRepoVariant.FP16 if precision == "float16" else None

     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
         return ModelInstallJob(
@@ -523,7 +523,7 @@ class ModelInstallService(ModelInstallServiceBase):
         if not source.access_token:
             self._logger.info("No HuggingFace access token present; some models may not be downloadable.")

-        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id)
+        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
         assert isinstance(metadata, ModelMetadataWithFiles)
         remote_files = metadata.download_urls(
             variant=source.variant or self._guess_variant(),

View File

@@ -30,6 +30,7 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
 from invokeai.backend.model_manager import (
     BaseModelType,
     InvalidModelConfigException,
+    ModelRepoVariant,
     ModelType,
 )
 from invokeai.backend.model_manager.metadata import UnknownMetadataException
@@ -233,11 +234,18 @@ class InstallHelper(object):
         if model_path.exists():  # local file on disk
             return LocalModelSource(path=model_path.absolute(), inplace=True)
-        if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url):  # hugging face repo_id
+        # parsing huggingface repo ids
+        # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
+        variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
+        if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
+            repo_id = match.group(1)
+            repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
             return HFModelSource(
-                repo_id=model_path_id_or_url,
+                repo_id=repo_id,
                 access_token=HfFolder.get_token(),
                 subfolder=model_info.subfolder,
+                variant=repo_variant,
             )
         if re.match(r"^(http|https):", model_path_id_or_url):
             return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
@@ -278,9 +286,11 @@ class InstallHelper(object):
                 model_name=model_name,
             )
             if len(matches) > 1:
-                print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
+                print(
+                    f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate."
+                )
             elif not matches:
-                print(f"{model}: unknown model")
+                print(f"{model_to_remove}: unknown model")
             else:
                 for m in matches:
                     print(f"Deleting {m.type}:{m.name}")

View File

@@ -109,7 +109,7 @@ class SchedulerPredictionType(str, Enum):
 class ModelRepoVariant(str, Enum):
     """Various hugging face variants on the diffusers format."""

-    DEFAULT = "default"  # model files without "fp16" or other qualifier
+    DEFAULT = ""  # model files without "fp16" or other qualifier - empty str
     FP16 = "fp16"
     FP32 = "fp32"
     ONNX = "onnx"
@@ -246,6 +246,16 @@ class ONNXSD2Config(_MainConfig):
     upcast_attention: bool = True


+class ONNXSDXLConfig(_MainConfig):
+    """Model config for ONNX format models based on sdxl."""
+
+    type: Literal[ModelType.ONNX] = ModelType.ONNX
+    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
+    # No yaml config file for ONNX, so these are part of config
+    base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
+
+
 class IPAdapterConfig(ModelConfigBase):
     """Model config for IP Adaptor format models."""
@@ -267,7 +277,7 @@ class T2IConfig(ModelConfigBase):
     format: Literal[ModelFormat.Diffusers]


-_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
+_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")]
 _ControlNetConfig = Annotated[
     Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
     Field(discriminator="format"),
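For context on how _ONNXConfig resolves to a concrete class, here is a self-contained pydantic sketch of a discriminated union keyed on "base". The field set is deliberately simplified and is not the real ModelConfigBase schema:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class ONNXSD1Sketch(BaseModel):
    base: Literal["sd-1"] = "sd-1"
    prediction_type: str = "epsilon"


class ONNXSDXLSketch(BaseModel):
    base: Literal["sdxl"] = "sdxl"
    prediction_type: str = "v_prediction"


# Mirrors the shape of _ONNXConfig: pydantic routes on the "base" field.
ONNXSketch = Annotated[Union[ONNXSD1Sketch, ONNXSDXLSketch], Field(discriminator="base")]

adapter = TypeAdapter(ONNXSketch)
cfg = adapter.validate_python({"base": "sdxl"})
print(type(cfg).__name__)  # ONNXSDXLSketch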

View File

@@ -16,7 +16,6 @@ from .model_cache.model_cache_default import ModelCache

 # This registers the subclasses that implement loaders of specific model types
 loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
 for module in loaders:
-    print(f"module={module}")
     import_module(f"{__package__}.model_loaders.{module}")

 __all__ = ["AnyModelLoader", "LoadedModel"]

View File

@@ -22,6 +22,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelTy
 from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
 from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
+from invokeai.backend.util.logging import InvokeAILogger


 @dataclass
@@ -88,6 +89,7 @@ class AnyModelLoader:
     # this tracks the loader subclasses
     _registry: Dict[str, Type[ModelLoaderBase]] = {}
+    _logger: Logger = InvokeAILogger.get_logger()

     def __init__(
         self,
@@ -167,7 +169,7 @@ class AnyModelLoader:
         """Define a decorator which registers the subclass of loader."""

         def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
-            print("DEBUG: Registering class", subclass.__name__)
+            cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
             key = cls._to_registry_key(base, type, format)
             cls._registry[key] = subclass
             return subclass
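The register() decorator above is an instance of the registry pattern; a stripped-down sketch with illustrative names follows (this is not the real AnyModelLoader API):

import logging
from typing import Callable, Dict, Type

logging.basicConfig(level=logging.DEBUG)


class LoaderRegistry:
    """Illustrative stand-in for AnyModelLoader's class-level registry."""

    _registry: Dict[str, Type] = {}
    _logger = logging.getLogger("LoaderRegistry")

    @classmethod
    def register(cls, base: str, type: str, format: str) -> Callable[[Type], Type]:
        def decorator(subclass: Type) -> Type:
            # log and remember which class handles this base/type/format key
            cls._logger.debug("Registering %s for %s/%s/%s", subclass.__name__, base, type, format)
            cls._registry[f"{base}/{type}/{format}"] = subclass
            return subclass

        return decorator


@LoaderRegistry.register(base="any", type="onnx", format="onnx")
class OnnxLoaderSketch:
    pass


print(LoaderRegistry._registry)  # {'any/onnx/onnx': <class '...OnnxLoaderSketch'>}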

View File

@@ -52,7 +52,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._torch_dtype = torch_dtype(choose_torch_device(), app_config)

     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
@@ -102,8 +102,10 @@ class ModelLoader(ModelLoaderBase):
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
     ) -> ModelLockerBase:
         # TO DO: This is not thread safe!
-        if self._ram_cache.exists(config.key, submodel_type):
+        try:
             return self._ram_cache.get(config.key, submodel_type)
+        except IndexError:
+            pass

         model_variant = getattr(config, "repo_variant", None)
         self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
@@ -119,7 +121,11 @@ class ModelLoader(ModelLoaderBase):
             size=calc_model_size_by_data(loaded_model),
         )
-        return self._ram_cache.get(config.key, submodel_type)
+        return self._ram_cache.get(
+            key=config.key,
+            submodel_type=submodel_type,
+            stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
+        )

     def get_size_fs(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
@@ -146,13 +152,21 @@ class ModelLoader(ModelLoaderBase):
     # TO DO: Add exception handling
     def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
         if submodel_type:
-            config = self._load_diffusers_config(model_path, config_name="model_index.json")
-            module, class_name = config[submodel_type.value]
-            return self._hf_definition_to_type(module=module, class_name=class_name)
+            try:
+                config = self._load_diffusers_config(model_path, config_name="model_index.json")
+                module, class_name = config[submodel_type.value]
+                return self._hf_definition_to_type(module=module, class_name=class_name)
+            except KeyError as e:
+                raise InvalidModelConfigException(
+                    f'The "{submodel_type}" submodel is not available for this model.'
+                ) from e
         else:
-            config = self._load_diffusers_config(model_path, config_name="config.json")
-            class_name = config["_class_name"]
-            return self._hf_definition_to_type(module="diffusers", class_name=class_name)
+            try:
+                config = self._load_diffusers_config(model_path, config_name="config.json")
+                class_name = config["_class_name"]
+                return self._hf_definition_to_type(module="diffusers", class_name=class_name)
+            except KeyError as e:
+                raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

     # This needs to be implemented in subclasses that handle checkpoints
     def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
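Replacing the exists()/get() pair with a single get() wrapped in try/except IndexError removes a check-then-act window and matches the cache's documented contract. A minimal sketch of the pattern (TinyCache is a stand-in, not the real ModelCache):

from typing import Dict


class TinyCache:
    """Stand-in cache that raises IndexError on a miss, like ModelCacheBase.get()."""

    def __init__(self) -> None:
        self._entries: Dict[str, str] = {}

    def get(self, key: str) -> str:
        if key not in self._entries:
            raise IndexError(f"The model with key {key} is not in the cache.")
        return self._entries[key]

    def put(self, key: str, value: str) -> None:
        self._entries[key] = value


cache = TinyCache()
try:
    model = cache.get("sd-1/main/my-model")    # ask forgiveness...
except IndexError:
    cache.put("sd-1/main/my-model", "loaded")  # ...load and cache on a miss
    model = cache.get("sd-1/main/my-model")
print(model)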

View File

@@ -1,3 +1,4 @@
-"""Init file for RamCache."""
+"""Init file for ModelCache."""

 _all__ = ["ModelCacheBase", "ModelCache"]

View File

@@ -129,11 +129,17 @@ class ModelCacheBase(ABC, Generic[T]):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
+        stats_name: Optional[str] = None,
     ) -> ModelLockerBase:
         """
-        Retrieve model locker object using key and optional submodel_type.
+        Retrieve model using key and optional submodel_type.

-        This may return an UnknownModelException if the model is not in the cache.
+        :param key: Opaque model key
+        :param submodel_type: Type of the submodel to fetch
+        :param stats_name: A human-readable id for the model for the purposes of
+            stats reporting.
+
+        This may raise an IndexError if the model is not in the cache.
         """
         pass

View File

@@ -24,6 +24,7 @@ import math
 import sys
 import time
 from contextlib import suppress
+from dataclasses import dataclass, field
 from logging import Logger
 from typing import Dict, List, Optional
@@ -55,6 +56,20 @@ GIG = 1073741824
 MB = 2**20


+@dataclass
+class CacheStats(object):
+    """Collect statistics on cache performance."""
+
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    # {submodel_key => size}
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
+
+
 class ModelCache(ModelCacheBase[AnyModel]):
     """Implementation of ModelCacheBase."""
@@ -94,6 +109,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._storage_device: torch.device = storage_device
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
+        # used for stats collection
+        self.stats = CacheStats()

         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []
@@ -158,21 +175,40 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
+        stats_name: Optional[str] = None,
     ) -> ModelLockerBase:
         """
         Retrieve model using key and optional submodel_type.

-        This may return an IndexError if the model is not in the cache.
+        :param key: Opaque model key
+        :param submodel_type: Type of the submodel to fetch
+        :param stats_name: A human-readable id for the model for the purposes of
+            stats reporting.
+
+        This may raise an IndexError if the model is not in the cache.
         """
         key = self._make_cache_key(key, submodel_type)
-        if key not in self._cached_models:
+        if key in self._cached_models:
+            self.stats.hits += 1
+        else:
+            self.stats.misses += 1
             raise IndexError(f"The model with key {key} is not in the cache.")

+        cache_entry = self._cached_models[key]
+
+        # more stats
+        stats_name = stats_name or key
+        self.stats.cache_size = int(self._max_cache_size * GIG)
+        self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+        self.stats.in_cache = len(self._cached_models)
+        self.stats.loaded_model_sizes[stats_name] = max(
+            self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+        )
+
         # this moves the entry to the top (right end) of the stack
         with suppress(Exception):
             self._cache_stack.remove(key)
         self._cache_stack.append(key)
-        cache_entry = self._cached_models[key]
         return ModelLocker(
             cache=self,
             cache_entry=cache_entry,
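Once self.stats is populated, anything holding a reference to the cache can report on it. A sketch that assumes only the CacheStats fields defined above; how the cache instance is obtained in the application is not shown:

def format_cache_stats(stats) -> str:
    """Render the CacheStats fields added above; GIG matches the constant used in the diff."""
    GIG = 1073741824
    lines = [
        f"cache hits/misses: {stats.hits}/{stats.misses}",
        f"models in cache:   {stats.in_cache}",
        f"high watermark:    {stats.high_watermark / GIG:.2f} GB of {stats.cache_size / GIG:.2f} GB",
        f"models cleared:    {stats.cleared}",
    ]
    # per-model sizes keyed by the stats_name passed to get()
    lines += [f"  {name}: {size / GIG:.2f} GB" for name, size in stats.loaded_model_sizes.items()]
    return "\n".join(lines)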

View File

@@ -0,0 +1,41 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Class for Onnx model loading in InvokeAI."""
+
+# This should work the same as Stable Diffusion pipelines
+from pathlib import Path
+from typing import Optional
+
+from invokeai.backend.model_manager import (
+    AnyModel,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
+class OnnyxDiffusersModel(ModelLoader):
+    """Class to load onnx models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not submodel_type is not None:
+            raise Exception("A submodel type must be provided when loading onnx pipelines.")
+        load_class = self._get_hf_load_class(model_path, submodel_type)
+        variant = model_variant.value if model_variant else None
+        model_path = model_path / submodel_type.value
+        result: AnyModel = load_class.from_pretrained(
+            model_path,
+            torch_dtype=self._torch_dtype,
+            variant=variant,
+        )  # type: ignore
+        return result

View File

@@ -32,6 +32,8 @@ import requests
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session

+from invokeai.backend.model_manager import ModelRepoVariant
+
 from ..metadata_base import (
     AnyModelRepoMetadata,
     CivitaiMetadata,
@@ -82,10 +84,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
             return self.from_civitai_versionid(int(version_id))
         raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")

-    def from_id(self, id: str) -> AnyModelRepoMetadata:
+    def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
         """
         Given a Civitai model version ID, return a ModelRepoMetadata object.

+        :param id: An ID.
+        :param variant: A model variant from the ModelRepoVariant enum (currently ignored)
+
         May raise an `UnknownMetadataException`.
         """
         return self.from_civitai_versionid(int(id))

View File

@@ -18,6 +18,8 @@ from typing import Optional
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session

+from invokeai.backend.model_manager import ModelRepoVariant
+
 from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
@@ -45,10 +47,13 @@ class ModelMetadataFetchBase(ABC):
         pass

     @abstractmethod
-    def from_id(self, id: str) -> AnyModelRepoMetadata:
+    def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
         """
         Given an ID for a model, return a ModelMetadata object.

+        :param id: An ID.
+        :param variant: A model variant from the ModelRepoVariant enum.
+
         This method will raise a `UnknownMetadataException`
         in the event that the requested model's metadata is not found at the provided id.
         """

View File

@@ -19,10 +19,12 @@ from typing import Optional
 import requests
 from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
-from huggingface_hub.utils._errors import RepositoryNotFoundError
+from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session

+from invokeai.backend.model_manager import ModelRepoVariant
+
 from ..metadata_base import (
     AnyModelRepoMetadata,
     HuggingFaceMetadata,
@@ -53,12 +55,22 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
             metadata = HuggingFaceMetadata.model_validate_json(json)
             return metadata

-    def from_id(self, id: str) -> AnyModelRepoMetadata:
+    def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
         """Return a HuggingFaceMetadata object given the model's repo_id."""
-        try:
-            model_info = HfApi().model_info(repo_id=id, files_metadata=True)
-        except RepositoryNotFoundError as excp:
-            raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
+        # Little loop which tries fetching a revision corresponding to the selected variant.
+        # If not available, then set variant to None and get the default.
+        # If this too fails, raise exception.
+        model_info = None
+        while not model_info:
+            try:
+                model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant)
+            except RepositoryNotFoundError as excp:
+                raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
+            except RevisionNotFoundError:
+                if variant is None:
+                    raise
+                else:
+                    variant = None

         _, name = id.split("/")
         return HuggingFaceMetadata(
@@ -70,7 +82,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
             tags=model_info.tags,
             files=[
                 RemoteModelFile(
-                    url=hf_hub_url(id, x.rfilename),
+                    url=hf_hub_url(id, x.rfilename, revision=variant),
                     path=Path(name, x.rfilename),
                     size=x.size,
                     sha256=x.lfs.get("sha256") if x.lfs else None,
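The revision fallback loop can be exercised against huggingface_hub directly. A sketch using a public repo id; the helper name is hypothetical, network access is required, and the error import mirrors the one used in this commit:

from typing import Optional

from huggingface_hub import HfApi
from huggingface_hub.utils._errors import RevisionNotFoundError


def model_info_with_variant_fallback(repo_id: str, variant: Optional[str] = "fp16"):
    """Try the variant-named revision first; fall back to the default revision if it doesn't exist."""
    api = HfApi()
    while True:
        try:
            return api.model_info(repo_id=repo_id, files_metadata=True, revision=variant)
        except RevisionNotFoundError:
            if variant is None:
                raise
            variant = None  # retry once against the default revision


info = model_info_with_variant_fallback("stabilityai/sd-turbo", variant="fp16")
print(len(info.siblings or []), "files in the resolved revision")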

View File

@@ -184,7 +184,6 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
             [x.path for x in self.files], variant, subfolder
         )  # all files in the model
         prefix = f"{subfolder}/" if subfolder else ""
-
         # the next step reads model_index.json to determine which subdirectories belong
         # to the model
         if Path(f"{prefix}model_index.json") in paths:

View File

@@ -7,6 +7,7 @@ import safetensors.torch
 import torch
 from picklescan.scanner import scan_file_path

+import invokeai.backend.util.logging as logger
 from invokeai.backend.model_management.models.base import read_checkpoint_meta
 from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
 from invokeai.backend.model_management.util import lora_token_vector_length
@@ -590,13 +591,20 @@ class TextualInversionFolderProbe(FolderProbeBase):
         return TextualInversionCheckpointProbe(path).get_base_type()


-class ONNXFolderProbe(FolderProbeBase):
+class ONNXFolderProbe(PipelineFolderProbe):
+    def get_base_type(self) -> BaseModelType:
+        # Due to the way the installer is set up, the configuration file for safetensors
+        # will come along for the ride if both the onnx and safetensors forms
+        # share the same directory. We take advantage of this here.
+        if (self.model_path / "unet" / "config.json").exists():
+            return super().get_base_type()
+        else:
+            logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
+            return BaseModelType.StableDiffusion1
+
     def get_format(self) -> ModelFormat:
         return ModelFormat("onnx")

-    def get_base_type(self) -> BaseModelType:
-        return BaseModelType.StableDiffusion1
-
     def get_variant_type(self) -> ModelVariantType:
         return ModelVariantType.Normal

View File

@@ -41,13 +41,21 @@ def filter_files(
     for file in files:
         if file.name.endswith((".json", ".txt")):
             paths.append(file)
-        elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")):
+        elif file.name.endswith(
+            (
+                "learned_embeds.bin",
+                "ip_adapter.bin",
+                "lora_weights.safetensors",
+                "weights.pb",
+                "onnx_data",
+            )
+        ):
             paths.append(file)
         # BRITTLENESS WARNING!!
         # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
         # downloading random checkpoints that might also be in the repo. However there is no guarantee
         # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
-        # will adhere to this naming convention, so this is an area of brittleness.
+        # will adhere to this naming convention, so this is an area to be careful of.
         elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
             paths.append(file)
@@ -64,7 +72,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
     result = set()
     basenames: Dict[Path, Path] = {}
     for path in files:
-        if path.suffix == ".onnx":
+        if path.suffix in [".onnx", ".pb", ".onnx_data"]:
             if variant == ModelRepoVariant.ONNX:
                 result.add(path)
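The new suffix test treats ONNX external-weight files as part of the ONNX variant. A quick, standalone check of what Path.suffix yields for the file names involved (illustrative only; this is not the real filter_files call):

from pathlib import Path

for name in ["unet/model.onnx", "unet/model.onnx_data", "unet/weights.pb", "unet/model.safetensors"]:
    path = Path(name)
    kept_for_onnx = path.suffix in [".onnx", ".pb", ".onnx_data"]
    print(f"{name}: suffix={path.suffix!r} -> kept for ONNX variant: {kept_for_onnx}")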

View File

@@ -29,12 +29,17 @@ def choose_torch_device() -> torch.device:
         return torch.device(config.device)


-def choose_precision(device: torch.device) -> str:
-    """Returns an appropriate precision for the given torch device"""
+# We are in transition here from using a single global AppConfig to allowing multiple
+# configurations. It is strongly recommended to pass the app_config to this function.
+def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str:
+    """Return an appropriate precision for the given torch device."""
+    app_config = app_config or config
     if device.type == "cuda":
         device_name = torch.cuda.get_device_name(device)
         if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
-            if config.precision == "bfloat16":
+            if app_config.precision == "float32":
+                return "float32"
+            elif app_config.precision == "bfloat16":
                 return "bfloat16"
             else:
                 return "float16"
@@ -43,9 +48,14 @@ def choose_precision(device: torch.device) -> str:
     return "float32"


-def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
+# We are in transition here from using a single global AppConfig to allowing multiple
+# configurations. It is strongly recommended to pass the app_config to this function.
+def torch_dtype(
+    device: Optional[torch.device] = None,
+    app_config: Optional[InvokeAIAppConfig] = None,
+) -> torch.dtype:
     device = device or choose_torch_device()
-    precision = choose_precision(device)
+    precision = choose_precision(device, app_config)
     if precision == "float16":
         return torch.float16
     if precision == "bfloat16":
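With the extra parameter, callers can resolve precision against a specific configuration instead of the module-level global. A sketch of the intended call pattern; the module paths, the InvokeAIAppConfig import, and the get_config() accessor are assumptions based on how the rest of the codebase obtains its configuration:

from invokeai.app.services.config import InvokeAIAppConfig  # assumed import path
from invokeai.backend.util.devices import choose_precision, choose_torch_device, torch_dtype

app_config = InvokeAIAppConfig.get_config()  # assumed accessor for the active configuration
device = choose_torch_device()

precision = choose_precision(device, app_config)  # "float16", "bfloat16", or "float32"
dtype = torch_dtype(device, app_config)           # the corresponding torch.dtype
print(precision, dtype)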

View File

@@ -505,7 +505,7 @@ def list_models(installer: ModelInstallService, model_type: ModelType):
     print(f"Installed models of type `{model_type}`:")
     for model in models:
         path = (config.models_path / model.path).resolve()
-        print(f"{model.name:40}{model.base.value:14}{path}")
+        print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}")


 # --------------------------------------------------------
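The widened listing relies on f-string width specifiers to lay out fixed-width columns; each field is left-aligned and padded with spaces to the given width (the values below are illustrative):

name, base, type_, format_ = "stable-diffusion-v1-5", "sd-1", "main", "diffusers"
print(f"{name:40}{base:5}{type_:8}{format_:12}/path/to/model")
# pads the four values to widths 40, 5, 8 and 12 before printing the path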