mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Multiple refinements to the model loaders:
- Enabled cache statistics collection.
- Implemented ONNX model loading.
- Added the ability to specify the repo version variant in the installer CLI.
- If the caller asks for a repo version that doesn't exist, it now falls back to the empty (default) version rather than raising an error.
This commit is contained in:
parent
ad2926a24c
commit
fbded1c0f2
@ -495,10 +495,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
return id
|
return id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _guess_variant() -> ModelRepoVariant:
|
def _guess_variant() -> Optional[ModelRepoVariant]:
|
||||||
"""Guess the best HuggingFace variant type to download."""
|
"""Guess the best HuggingFace variant type to download."""
|
||||||
precision = choose_precision(choose_torch_device())
|
precision = choose_precision(choose_torch_device())
|
||||||
return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT
|
return ModelRepoVariant.FP16 if precision == "float16" else None
|
||||||
|
|
||||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
return ModelInstallJob(
|
return ModelInstallJob(
|
||||||
@ -523,7 +523,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
if not source.access_token:
|
if not source.access_token:
|
||||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||||
|
|
||||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id)
|
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
remote_files = metadata.download_urls(
|
remote_files = metadata.download_urls(
|
||||||
variant=source.variant or self._guess_variant(),
|
variant=source.variant or self._guess_variant(),
|
||||||
|
@ -30,6 +30,7 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
|||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
InvalidModelConfigException,
|
InvalidModelConfigException,
|
||||||
|
ModelRepoVariant,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||||
@ -233,11 +234,18 @@ class InstallHelper(object):
|
|||||||
|
|
||||||
if model_path.exists(): # local file on disk
|
if model_path.exists(): # local file on disk
|
||||||
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
||||||
if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id
|
|
||||||
|
# parsing huggingface repo ids
|
||||||
|
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
|
||||||
|
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
|
||||||
|
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
|
||||||
|
repo_id = match.group(1)
|
||||||
|
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
|
||||||
return HFModelSource(
|
return HFModelSource(
|
||||||
repo_id=model_path_id_or_url,
|
repo_id=repo_id,
|
||||||
access_token=HfFolder.get_token(),
|
access_token=HfFolder.get_token(),
|
||||||
subfolder=model_info.subfolder,
|
subfolder=model_info.subfolder,
|
||||||
|
variant=repo_variant,
|
||||||
)
|
)
|
||||||
if re.match(r"^(http|https):", model_path_id_or_url):
|
if re.match(r"^(http|https):", model_path_id_or_url):
|
||||||
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
||||||
@ -278,9 +286,11 @@ class InstallHelper(object):
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
if len(matches) > 1:
|
if len(matches) > 1:
|
||||||
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
|
print(
|
||||||
|
f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate."
|
||||||
|
)
|
||||||
elif not matches:
|
elif not matches:
|
||||||
print(f"{model}: unknown model")
|
print(f"{model_to_remove}: unknown model")
|
||||||
else:
|
else:
|
||||||
for m in matches:
|
for m in matches:
|
||||||
print(f"Deleting {m.type}:{m.name}")
|
print(f"Deleting {m.type}:{m.name}")
|
||||||
|
@ -109,7 +109,7 @@ class SchedulerPredictionType(str, Enum):
|
|||||||
class ModelRepoVariant(str, Enum):
|
class ModelRepoVariant(str, Enum):
|
||||||
"""Various hugging face variants on the diffusers format."""
|
"""Various hugging face variants on the diffusers format."""
|
||||||
|
|
||||||
DEFAULT = "default" # model files without "fp16" or other qualifier
|
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||||
FP16 = "fp16"
|
FP16 = "fp16"
|
||||||
FP32 = "fp32"
|
FP32 = "fp32"
|
||||||
ONNX = "onnx"
|
ONNX = "onnx"
|
||||||
@ -246,6 +246,16 @@ class ONNXSD2Config(_MainConfig):
|
|||||||
upcast_attention: bool = True
|
upcast_attention: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXSDXLConfig(_MainConfig):
    """Configuration record for SDXL-based models stored in ONNX format."""

    # Always an ONNX-type model; the on-disk format is either plain ONNX or Olive.
    type: Literal[ModelType.ONNX] = ModelType.ONNX
    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    # ONNX models ship without a yaml config file, so the base model and
    # prediction type are carried directly in this config.
    base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterConfig(ModelConfigBase):
|
class IPAdapterConfig(ModelConfigBase):
|
||||||
"""Model config for IP Adaptor format models."""
|
"""Model config for IP Adaptor format models."""
|
||||||
|
|
||||||
@ -267,7 +277,7 @@ class T2IConfig(ModelConfigBase):
|
|||||||
format: Literal[ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
|
|
||||||
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
|
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")]
|
||||||
_ControlNetConfig = Annotated[
|
_ControlNetConfig = Annotated[
|
||||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||||
Field(discriminator="format"),
|
Field(discriminator="format"),
|
||||||
|
@ -16,7 +16,6 @@ from .model_cache.model_cache_default import ModelCache
|
|||||||
# This registers the subclasses that implement loaders of specific model types
|
# This registers the subclasses that implement loaders of specific model types
|
||||||
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
||||||
for module in loaders:
|
for module in loaders:
|
||||||
print(f"module={module}")
|
|
||||||
import_module(f"{__package__}.model_loaders.{module}")
|
import_module(f"{__package__}.model_loaders.{module}")
|
||||||
|
|
||||||
__all__ = ["AnyModelLoader", "LoadedModel"]
|
__all__ = ["AnyModelLoader", "LoadedModel"]
|
||||||
|
@ -22,6 +22,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelTy
|
|||||||
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
|
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
|
||||||
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -88,6 +89,7 @@ class AnyModelLoader:
|
|||||||
|
|
||||||
# this tracks the loader subclasses
|
# this tracks the loader subclasses
|
||||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||||
|
_logger: Logger = InvokeAILogger.get_logger()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -167,7 +169,7 @@ class AnyModelLoader:
|
|||||||
"""Define a decorator which registers the subclass of loader."""
|
"""Define a decorator which registers the subclass of loader."""
|
||||||
|
|
||||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||||
print("DEBUG: Registering class", subclass.__name__)
|
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
|
||||||
key = cls._to_registry_key(base, type, format)
|
key = cls._to_registry_key(base, type, format)
|
||||||
cls._registry[key] = subclass
|
cls._registry[key] = subclass
|
||||||
return subclass
|
return subclass
|
||||||
|
@ -52,7 +52,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._ram_cache = ram_cache
|
self._ram_cache = ram_cache
|
||||||
self._convert_cache = convert_cache
|
self._convert_cache = convert_cache
|
||||||
self._torch_dtype = torch_dtype(choose_torch_device())
|
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
|
||||||
|
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
@ -102,8 +102,10 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||||
) -> ModelLockerBase:
|
) -> ModelLockerBase:
|
||||||
# TO DO: This is not thread safe!
|
# TO DO: This is not thread safe!
|
||||||
if self._ram_cache.exists(config.key, submodel_type):
|
try:
|
||||||
return self._ram_cache.get(config.key, submodel_type)
|
return self._ram_cache.get(config.key, submodel_type)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
model_variant = getattr(config, "repo_variant", None)
|
model_variant = getattr(config, "repo_variant", None)
|
||||||
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||||
@ -119,7 +121,11 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
size=calc_model_size_by_data(loaded_model),
|
size=calc_model_size_by_data(loaded_model),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._ram_cache.get(config.key, submodel_type)
|
return self._ram_cache.get(
|
||||||
|
key=config.key,
|
||||||
|
submodel_type=submodel_type,
|
||||||
|
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
|
||||||
|
)
|
||||||
|
|
||||||
def get_size_fs(
|
def get_size_fs(
|
||||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||||
@ -146,13 +152,21 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
# TO DO: Add exception handling
|
# TO DO: Add exception handling
|
||||||
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||||
if submodel_type:
|
if submodel_type:
|
||||||
|
try:
|
||||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||||
module, class_name = config[submodel_type.value]
|
module, class_name = config[submodel_type.value]
|
||||||
return self._hf_definition_to_type(module=module, class_name=class_name)
|
return self._hf_definition_to_type(module=module, class_name=class_name)
|
||||||
|
except KeyError as e:
|
||||||
|
raise InvalidModelConfigException(
|
||||||
|
f'The "{submodel_type}" submodel is not available for this model.'
|
||||||
|
) from e
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||||
class_name = config["_class_name"]
|
class_name = config["_class_name"]
|
||||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||||
|
except KeyError as e:
|
||||||
|
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||||
|
|
||||||
# This needs to be implemented in subclasses that handle checkpoints
|
# This needs to be implemented in subclasses that handle checkpoints
|
||||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
"""Init file for RamCache."""
|
"""Init file for ModelCache."""
|
||||||
|
|
||||||
|
|
||||||
_all__ = ["ModelCacheBase", "ModelCache"]
|
_all__ = ["ModelCacheBase", "ModelCache"]
|
||||||
|
@ -129,11 +129,17 @@ class ModelCacheBase(ABC, Generic[T]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
stats_name: Optional[str] = None,
|
||||||
) -> ModelLockerBase:
|
) -> ModelLockerBase:
|
||||||
"""
|
"""
|
||||||
Retrieve model locker object using key and optional submodel_type.
|
Retrieve model using key and optional submodel_type.
|
||||||
|
|
||||||
This may return an UnknownModelException if the model is not in the cache.
|
:param key: Opaque model key
|
||||||
|
:param submodel_type: Type of the submodel to fetch
|
||||||
|
:param stats_name: A human-readable id for the model for the purposes of
|
||||||
|
stats reporting.
|
||||||
|
|
||||||
|
This may raise an IndexError if the model is not in the cache.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ import math
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
@ -55,6 +56,20 @@ GIG = 1073741824
|
|||||||
MB = 2**20
|
MB = 2**20
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CacheStats:
    """Collect statistics on cache performance.

    The counters are plain integers that the model cache updates as models
    are requested; a fresh instance starts with every counter at zero and an
    empty size table.
    """

    hits: int = 0  # number of requests that found the model in the cache
    misses: int = 0  # number of requests for a model that was not in the cache
    high_watermark: int = 0  # largest amount of cache used so far
    in_cache: int = 0  # number of models currently in the cache
    cleared: int = 0  # number of models cleared to make space
    cache_size: int = 0  # total (maximum) size of the cache
    # {stats name or cache key => largest size recorded for that model}
    # default_factory gives every instance its own dict rather than a shared one.
    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ModelCache(ModelCacheBase[AnyModel]):
|
class ModelCache(ModelCacheBase[AnyModel]):
|
||||||
"""Implementation of ModelCacheBase."""
|
"""Implementation of ModelCacheBase."""
|
||||||
|
|
||||||
@ -94,6 +109,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
self._storage_device: torch.device = storage_device
|
self._storage_device: torch.device = storage_device
|
||||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||||
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
|
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
|
||||||
|
# used for stats collection
|
||||||
|
self.stats = CacheStats()
|
||||||
|
|
||||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||||
self._cache_stack: List[str] = []
|
self._cache_stack: List[str] = []
|
||||||
@ -158,21 +175,40 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
stats_name: Optional[str] = None,
|
||||||
) -> ModelLockerBase:
|
) -> ModelLockerBase:
|
||||||
"""
|
"""
|
||||||
Retrieve model using key and optional submodel_type.
|
Retrieve model using key and optional submodel_type.
|
||||||
|
|
||||||
This may return an IndexError if the model is not in the cache.
|
:param key: Opaque model key
|
||||||
|
:param submodel_type: Type of the submodel to fetch
|
||||||
|
:param stats_name: A human-readable id for the model for the purposes of
|
||||||
|
stats reporting.
|
||||||
|
|
||||||
|
This may raise an IndexError if the model is not in the cache.
|
||||||
"""
|
"""
|
||||||
key = self._make_cache_key(key, submodel_type)
|
key = self._make_cache_key(key, submodel_type)
|
||||||
if key not in self._cached_models:
|
if key in self._cached_models:
|
||||||
|
self.stats.hits += 1
|
||||||
|
else:
|
||||||
|
self.stats.misses += 1
|
||||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||||
|
|
||||||
|
cache_entry = self._cached_models[key]
|
||||||
|
|
||||||
|
# more stats
|
||||||
|
stats_name = stats_name or key
|
||||||
|
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||||
|
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||||
|
self.stats.in_cache = len(self._cached_models)
|
||||||
|
self.stats.loaded_model_sizes[stats_name] = max(
|
||||||
|
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||||
|
)
|
||||||
|
|
||||||
# this moves the entry to the top (right end) of the stack
|
# this moves the entry to the top (right end) of the stack
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
self._cache_stack.remove(key)
|
self._cache_stack.remove(key)
|
||||||
self._cache_stack.append(key)
|
self._cache_stack.append(key)
|
||||||
cache_entry = self._cached_models[key]
|
|
||||||
return ModelLocker(
|
return ModelLocker(
|
||||||
cache=self,
|
cache=self,
|
||||||
cache_entry=cache_entry,
|
cache_entry=cache_entry,
|
||||||
|
41
invokeai/backend/model_manager/load/model_loaders/onnx.py
Normal file
41
invokeai/backend/model_manager/load/model_loaders/onnx.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""Class for Onnx model loading in InvokeAI."""
|
||||||
|
|
||||||
|
# This should work the same as Stable Diffusion pipelines
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import (
|
||||||
|
AnyModel,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelRepoVariant,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||||
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
|
|
||||||
|
|
||||||
|
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(ModelLoader):
    """Class to load onnx models.

    Registered for ONNX-type models of any base, in both the Onnx and Olive
    formats.
    """

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load one submodel of an ONNX pipeline from disk.

        :param model_path: Root directory of the model; the submodel is read
            from the subdirectory named after ``submodel_type.value``.
        :param model_variant: Optional repo variant; its value is passed as
            ``variant`` to ``from_pretrained`` (``None`` when not given).
        :param submodel_type: Which submodel to load. Required for ONNX pipelines.

        Raises an Exception if ``submodel_type`` is not provided.
        """
        # ONNX pipelines are only ever loaded one submodel at a time.
        if submodel_type is None:
            raise Exception("A submodel type must be provided when loading onnx pipelines.")

        load_class = self._get_hf_load_class(model_path, submodel_type)
        # A falsy variant (None, or an enum member whose string value is empty)
        # is passed through as None.
        variant = model_variant.value if model_variant else None
        model_path = model_path / submodel_type.value
        result: AnyModel = load_class.from_pretrained(
            model_path,
            torch_dtype=self._torch_dtype,
            variant=variant,
        )  # type: ignore
        return result
|
@ -32,6 +32,8 @@ import requests
|
|||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import ModelRepoVariant
|
||||||
|
|
||||||
from ..metadata_base import (
|
from ..metadata_base import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
CivitaiMetadata,
|
CivitaiMetadata,
|
||||||
@ -82,10 +84,13 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
return self.from_civitai_versionid(int(version_id))
|
return self.from_civitai_versionid(int(version_id))
|
||||||
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
|
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
|
||||||
|
|
||||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
|
||||||
"""
|
"""
|
||||||
Given a Civitai model version ID, return a ModelRepoMetadata object.
|
Given a Civitai model version ID, return a ModelRepoMetadata object.
|
||||||
|
|
||||||
|
:param id: An ID.
|
||||||
|
:param variant: A model variant from the ModelRepoVariant enum (currently ignored)
|
||||||
|
|
||||||
May raise an `UnknownMetadataException`.
|
May raise an `UnknownMetadataException`.
|
||||||
"""
|
"""
|
||||||
return self.from_civitai_versionid(int(id))
|
return self.from_civitai_versionid(int(id))
|
||||||
|
@ -18,6 +18,8 @@ from typing import Optional
|
|||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import ModelRepoVariant
|
||||||
|
|
||||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
||||||
|
|
||||||
|
|
||||||
@ -45,10 +47,13 @@ class ModelMetadataFetchBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
|
||||||
"""
|
"""
|
||||||
Given an ID for a model, return a ModelMetadata object.
|
Given an ID for a model, return a ModelMetadata object.
|
||||||
|
|
||||||
|
:param id: An ID.
|
||||||
|
:param variant: A model variant from the ModelRepoVariant enum.
|
||||||
|
|
||||||
This method will raise a `UnknownMetadataException`
|
This method will raise a `UnknownMetadataException`
|
||||||
in the event that the requested model's metadata is not found at the provided id.
|
in the event that the requested model's metadata is not found at the provided id.
|
||||||
"""
|
"""
|
||||||
|
@ -19,10 +19,12 @@ from typing import Optional
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
|
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
|
||||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import ModelRepoVariant
|
||||||
|
|
||||||
from ..metadata_base import (
|
from ..metadata_base import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
HuggingFaceMetadata,
|
HuggingFaceMetadata,
|
||||||
@ -53,12 +55,22 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
metadata = HuggingFaceMetadata.model_validate_json(json)
|
metadata = HuggingFaceMetadata.model_validate_json(json)
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
|
||||||
"""Return a HuggingFaceMetadata object given the model's repo_id."""
|
"""Return a HuggingFaceMetadata object given the model's repo_id."""
|
||||||
|
# Little loop which tries fetching a revision corresponding to the selected variant.
|
||||||
|
# If not available, then set variant to None and get the default.
|
||||||
|
# If this too fails, raise exception.
|
||||||
|
model_info = None
|
||||||
|
while not model_info:
|
||||||
try:
|
try:
|
||||||
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
|
model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant)
|
||||||
except RepositoryNotFoundError as excp:
|
except RepositoryNotFoundError as excp:
|
||||||
raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
|
raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
|
||||||
|
except RevisionNotFoundError:
|
||||||
|
if variant is None:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
variant = None
|
||||||
|
|
||||||
_, name = id.split("/")
|
_, name = id.split("/")
|
||||||
return HuggingFaceMetadata(
|
return HuggingFaceMetadata(
|
||||||
@ -70,7 +82,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
tags=model_info.tags,
|
tags=model_info.tags,
|
||||||
files=[
|
files=[
|
||||||
RemoteModelFile(
|
RemoteModelFile(
|
||||||
url=hf_hub_url(id, x.rfilename),
|
url=hf_hub_url(id, x.rfilename, revision=variant),
|
||||||
path=Path(name, x.rfilename),
|
path=Path(name, x.rfilename),
|
||||||
size=x.size,
|
size=x.size,
|
||||||
sha256=x.lfs.get("sha256") if x.lfs else None,
|
sha256=x.lfs.get("sha256") if x.lfs else None,
|
||||||
|
@ -184,7 +184,6 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
|||||||
[x.path for x in self.files], variant, subfolder
|
[x.path for x in self.files], variant, subfolder
|
||||||
) # all files in the model
|
) # all files in the model
|
||||||
prefix = f"{subfolder}/" if subfolder else ""
|
prefix = f"{subfolder}/" if subfolder else ""
|
||||||
|
|
||||||
# the next step reads model_index.json to determine which subdirectories belong
|
# the next step reads model_index.json to determine which subdirectories belong
|
||||||
# to the model
|
# to the model
|
||||||
if Path(f"{prefix}model_index.json") in paths:
|
if Path(f"{prefix}model_index.json") in paths:
|
||||||
|
@ -7,6 +7,7 @@ import safetensors.torch
|
|||||||
import torch
|
import torch
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.model_management.models.base import read_checkpoint_meta
|
from invokeai.backend.model_management.models.base import read_checkpoint_meta
|
||||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||||
from invokeai.backend.model_management.util import lora_token_vector_length
|
from invokeai.backend.model_management.util import lora_token_vector_length
|
||||||
@ -590,13 +591,20 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
|||||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
class ONNXFolderProbe(FolderProbeBase):
|
class ONNXFolderProbe(PipelineFolderProbe):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
# Due to the way the installer is set up, the configuration file for safetensors
|
||||||
|
# will come along for the ride if both the onnx and safetensors forms
|
||||||
|
# share the same directory. We take advantage of this here.
|
||||||
|
if (self.model_path / "unet" / "config.json").exists():
|
||||||
|
return super().get_base_type()
|
||||||
|
else:
|
||||||
|
logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"')
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
def get_format(self) -> ModelFormat:
|
def get_format(self) -> ModelFormat:
|
||||||
return ModelFormat("onnx")
|
return ModelFormat("onnx")
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
|
||||||
return BaseModelType.StableDiffusion1
|
|
||||||
|
|
||||||
def get_variant_type(self) -> ModelVariantType:
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
@ -41,13 +41,21 @@ def filter_files(
|
|||||||
for file in files:
|
for file in files:
|
||||||
if file.name.endswith((".json", ".txt")):
|
if file.name.endswith((".json", ".txt")):
|
||||||
paths.append(file)
|
paths.append(file)
|
||||||
elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")):
|
elif file.name.endswith(
|
||||||
|
(
|
||||||
|
"learned_embeds.bin",
|
||||||
|
"ip_adapter.bin",
|
||||||
|
"lora_weights.safetensors",
|
||||||
|
"weights.pb",
|
||||||
|
"onnx_data",
|
||||||
|
)
|
||||||
|
):
|
||||||
paths.append(file)
|
paths.append(file)
|
||||||
# BRITTLENESS WARNING!!
|
# BRITTLENESS WARNING!!
|
||||||
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
|
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
|
||||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||||
# will adhere to this naming convention, so this is an area of brittleness.
|
# will adhere to this naming convention, so this is an area to be careful of.
|
||||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||||
paths.append(file)
|
paths.append(file)
|
||||||
|
|
||||||
@ -64,7 +72,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
|||||||
result = set()
|
result = set()
|
||||||
basenames: Dict[Path, Path] = {}
|
basenames: Dict[Path, Path] = {}
|
||||||
for path in files:
|
for path in files:
|
||||||
if path.suffix == ".onnx":
|
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
|
||||||
if variant == ModelRepoVariant.ONNX:
|
if variant == ModelRepoVariant.ONNX:
|
||||||
result.add(path)
|
result.add(path)
|
||||||
|
|
||||||
|
@ -29,12 +29,17 @@ def choose_torch_device() -> torch.device:
|
|||||||
return torch.device(config.device)
|
return torch.device(config.device)
|
||||||
|
|
||||||
|
|
||||||
def choose_precision(device: torch.device) -> str:
|
# We are in transition here from using a single global AppConfig to allowing multiple
|
||||||
"""Returns an appropriate precision for the given torch device"""
|
# configurations. It is strongly recommended to pass the app_config to this function.
|
||||||
|
def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str:
|
||||||
|
"""Return an appropriate precision for the given torch device."""
|
||||||
|
app_config = app_config or config
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||||
if config.precision == "bfloat16":
|
if app_config.precision == "float32":
|
||||||
|
return "float32"
|
||||||
|
elif app_config.precision == "bfloat16":
|
||||||
return "bfloat16"
|
return "bfloat16"
|
||||||
else:
|
else:
|
||||||
return "float16"
|
return "float16"
|
||||||
@ -43,9 +48,14 @@ def choose_precision(device: torch.device) -> str:
|
|||||||
return "float32"
|
return "float32"
|
||||||
|
|
||||||
|
|
||||||
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
|
# We are in transition here from using a single global AppConfig to allowing multiple
|
||||||
|
# configurations. It is strongly recommended to pass the app_config to this function.
|
||||||
|
def torch_dtype(
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
app_config: Optional[InvokeAIAppConfig] = None,
|
||||||
|
) -> torch.dtype:
|
||||||
device = device or choose_torch_device()
|
device = device or choose_torch_device()
|
||||||
precision = choose_precision(device)
|
precision = choose_precision(device, app_config)
|
||||||
if precision == "float16":
|
if precision == "float16":
|
||||||
return torch.float16
|
return torch.float16
|
||||||
if precision == "bfloat16":
|
if precision == "bfloat16":
|
||||||
|
@ -505,7 +505,7 @@ def list_models(installer: ModelInstallService, model_type: ModelType):
|
|||||||
print(f"Installed models of type `{model_type}`:")
|
print(f"Installed models of type `{model_type}`:")
|
||||||
for model in models:
|
for model in models:
|
||||||
path = (config.models_path / model.path).resolve()
|
path = (config.models_path / model.path).resolve()
|
||||||
print(f"{model.name:40}{model.base.value:14}{path}")
|
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}")
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
|
Loading…
Reference in New Issue
Block a user