From 208d3907793a0197ab483ea2dd3542b2eb946bdf Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 29 Sep 2023 19:23:08 -0400 Subject: [PATCH] almost all type mismatches fixed --- invokeai/app/api/routers/models.py | 11 +- invokeai/app/services/events.py | 2 +- .../app/services/model_manager_service.py | 48 ++++---- invokeai/backend/model_manager/__init__.py | 4 +- invokeai/backend/model_manager/cache.py | 15 ++- invokeai/backend/model_manager/config.py | 3 + .../convert_ckpt_to_diffusers.py | 29 +++-- .../backend/model_manager/download/base.py | 13 +-- .../backend/model_manager/download/queue.py | 56 ++++----- invokeai/backend/model_manager/hash.py | 4 +- invokeai/backend/model_manager/install.py | 59 ++++++---- invokeai/backend/model_manager/loader.py | 6 +- invokeai/backend/model_manager/lora.py | 14 +-- invokeai/backend/model_manager/merge.py | 10 +- invokeai/backend/model_manager/probe.py | 8 +- invokeai/backend/model_manager/search.py | 8 +- .../backend/model_manager/storage/__init__.py | 4 +- .../backend/model_manager/storage/base.py | 24 ++-- .../backend/model_manager/storage/migrate.py | 16 +-- invokeai/backend/model_manager/storage/sql.py | 10 +- .../backend/model_manager/storage/yaml.py | 9 +- invokeai/backend/model_manager/util.py | 2 +- invokeai/backend/util/logging.py | 4 +- invokeai/backend/util/util.py | 108 ------------------ 24 files changed, 185 insertions(+), 282 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index a4a5d49051..2a3a06e928 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -3,7 +3,7 @@ import pathlib from enum import Enum -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional, Union, Tuple from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter @@ -22,14 +22,17 @@ from invokeai.backend.model_manager import ( from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException from invokeai.backend.model_manager.merge import MergeInterpolationMethod -from ..dependencies import ApiDependencies +from invokeai.app.api.dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) # NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config -# such as "MainCheckpointConfig" are repackaged by code original written by Stalker +# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker # into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig` # This is the reason for the calls to dict() followed by pydantic.parse_obj_as() + +# There are still numerous mypy errors here because it does not seem to like this +# way of dynamically generating the typing hints below. UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] @@ -38,7 +41,7 @@ ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): - models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] + models: List[Union[tuple(OPENAPI_MODEL_CONFIGS)]] class ModelImportStatus(BaseModel): diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index aa5d3a3672..5a7d3a0237 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -3,9 +3,9 @@ from typing import Any, Optional from invokeai.app.models.image import ProgressImage -from invokeai.app.services.model_manager_service import ModelInfo, SubModelType from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem from invokeai.app.util.misc import get_timestamp +from invokeai.backend.model_manager import ModelInfo, SubModelType from invokeai.backend.model_manager.download import DownloadJobBase from invokeai.backend.util.logging import InvokeAILogger diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index d62b2096ae..2422e7f3f0 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -26,6 +26,7 @@ from invokeai.backend.model_manager import ( from invokeai.backend.model_manager.cache import CacheStats from invokeai.backend.model_manager.download import DownloadJobBase from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger +from .events import EventServiceBase from .config import InvokeAIAppConfig @@ -33,11 +34,6 @@ if TYPE_CHECKING: from ..invocations.baseinvocation import InvocationContext -# "forward declaration" because of circular import issues -class EventServiceBase: - pass - - class ModelManagerServiceBase(ABC): """Responsible for managing models on disk and in memory.""" @@ -197,8 +193,8 @@ class ModelManagerServiceBase(ABC): def install_model( self, source: Union[str, Path, AnyHttpUrl], + priority: int = 10, model_attributes: Optional[Dict[str, Any]] = None, - priority: Optional[int] = 10, ) -> ModelInstallJob: """Import a path, repo_id or URL. Returns an ModelInstallJob. @@ -230,7 +226,7 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def wait_for_installs(self) -> Dict[str, str]: + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: """ Wait for all pending installs to complete. @@ -334,7 +330,7 @@ class ModelManagerService(ModelManagerServiceBase): """Responsible for managing models on disk and in memory.""" _loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process") - _event_bus: "EventServiceBase" = Field(description="an event bus to send install events to", default=None) + _event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None) def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None): """ @@ -345,8 +341,10 @@ class ModelManagerService(ModelManagerServiceBase): installation and download events will be sent to the event bus. """ self._event_bus = event_bus - handlers = [self._event_bus.emit_model_event] if self._event_bus else None - self._loader = ModelLoad(config, event_handlers=handlers) + kwargs: Dict[str, Any] = {} + if self._event_bus: + kwargs.update(event_handlers=[self._event_bus.emit_model_event]) + self._loader = ModelLoad(config, **kwargs) def get_model( self, @@ -365,8 +363,8 @@ class ModelManagerService(ModelManagerServiceBase): if context: self._emit_load_event( context=context, - key=key, - submodel_type=submodel_type, + model_key=key, + submodel=submodel_type, model_info=model_info, ) @@ -416,16 +414,18 @@ class ModelManagerService(ModelManagerServiceBase): assertion error if the name already exists. """ self.logger.debug(f"add/update model {model_path}") - return self._loader.installer.install( - model_path, - probe_override=model_attributes, + return ModelInstallJob.parse_obj( + self._loader.installer.install( + model_path, + probe_override=model_attributes, + ) ) def install_model( self, source: Union[str, Path, AnyHttpUrl], + priority: int = 10, model_attributes: Optional[Dict[str, Any]] = None, - priority: Optional[int] = 10, ) -> ModelInstallJob: """ Add a model using a path, repo_id or URL. @@ -438,11 +438,13 @@ class ModelManagerService(ModelManagerServiceBase): """ self.logger.debug(f"add model {source}") variant = "fp16" if self._loader.precision == "float16" else None - return self._loader.installer.install( - source, - probe_override=model_attributes, - variant=variant, - priority=priority, + return ModelInstallJob.parse_obj( + self._loader.installer.install( + source, + probe_override=model_attributes, + variant=variant, + priority=priority, + ) ) def list_install_jobs(self) -> List[ModelInstallJob]: @@ -453,9 +455,9 @@ class ModelManagerService(ModelManagerServiceBase): def id_to_job(self, id: int) -> ModelInstallJob: """Return the ModelInstallJob instance corresponding to the given job ID.""" - return self._loader.queue.id_to_job(id) + return ModelInstallJob.parse_obj(self._loader.queue.id_to_job(id)) - def wait_for_installs(self) -> Dict[str, str]: + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: """ Wait for all pending installs to complete. diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 2cf90f2e8b..e96ac1e668 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,6 +1,4 @@ -""" -Initialization file for invokeai.backend.model_manager.config -""" +"""Initialization file for invokeai.backend.model_manager.config.""" from .config import ( # noqa F401 BaseModelType, InvalidModelConfigException, diff --git a/invokeai/backend/model_manager/cache.py b/invokeai/backend/model_manager/cache.py index 4c357634a9..1d02e7d3bc 100644 --- a/invokeai/backend/model_manager/cache.py +++ b/invokeai/backend/model_manager/cache.py @@ -1,5 +1,6 @@ """ Manage a RAM cache of diffusion/transformer models for fast switching. + They are moved between GPU VRAM and CPU RAM as necessary. If the cache grows larger than a preset maximum, then the least recently used model will be cleared and (re)loaded from disk when next needed. @@ -22,11 +23,11 @@ import sys from contextlib import suppress from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Optional, Type, Union, types +from typing import Any, Dict, List, Optional, Type, Union import torch -import invokeai.backend.util.logging as logger +from invokeai.backend.util import InvokeAILogger, Logger from ..util import GIG from ..util.devices import choose_torch_device @@ -97,7 +98,7 @@ class ModelCache(object): sequential_offload: bool = False, lazy_offloading: bool = True, sha_chunksize: int = 16777216, - logger: types.ModuleType = logger, + logger: Logger = InvokeAILogger.get_logger(), ): """ :param max_cache_size: Maximum size of the RAM cache [6.0 GB] @@ -122,8 +123,8 @@ class ModelCache(object): # used for stats collection self.stats: Optional[CacheStats] = None - self._cached_models = dict() - self._cache_stack = list() + self._cached_models: Dict[str, _CacheRecord] = dict() + self._cache_stack: List[str] = list() # Note that the combination of model_path and submodel_type # are sufficient to generate a unique cache key. This key @@ -221,8 +222,12 @@ class ModelCache(object): return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size) class ModelLocker(object): + """Context manager that locks models into VRAM.""" + def __init__(self, cache, key, model, gpu_load, size_needed): """ + Initialize a context manager object that locks models into VRAM. + :param cache: The model_cache object :param key: The key of the model to lock in GPU :param model: The model to lock diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 201b270d60..9317a898ba 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -240,6 +240,7 @@ AnyModelConfig = Union[ TextualInversionConfig, ONNXSD1Config, ONNXSD2Config, + ModelConfigBase, ] @@ -324,6 +325,8 @@ class ModelConfigFactory(object): # TO DO: Move this somewhere else class SilenceWarnings(object): + """Context manager to temporarily lower verbosity of diffusers & transformers warning messages.""" + def __init__(self): self.transformers_verbosity = transformers_logging.get_verbosity() self.diffusers_verbosity = diffusers_logging.get_verbosity() diff --git a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py index 0a3a63dad6..ed89d78306 100644 --- a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -19,9 +19,8 @@ import re from contextlib import nullcontext -from io import BytesIO from pathlib import Path -from typing import Optional, Union +from typing import Dict, Optional, Union import requests import torch @@ -1223,7 +1222,7 @@ def download_from_original_stable_diffusion_ckpt( # scan model scan_result = scan_file_path(checkpoint_path) if scan_result.infected_files != 0: - raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." + raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) @@ -1272,15 +1271,15 @@ def download_from_original_stable_diffusion_ckpt( # only refiner xl has embedder and one text embedders config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" - original_config_file = BytesIO(requests.get(config_url).content) + original_config_file = requests.get(config_url).text original_config = OmegaConf.load(original_config_file) - if original_config["model"]["params"].get("use_ema") is not None: - extract_ema = original_config["model"]["params"]["use_ema"] + if original_config.model["params"].get("use_ema") is not None: + extract_ema = original_config.model["params"]["use_ema"] if ( model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1] - and original_config["model"]["params"].get("parameterization") == "v" + and original_config.model["params"].get("parameterization") == "v" ): prediction_type = "v_prediction" upcast_attention = True @@ -1312,11 +1311,11 @@ def download_from_original_stable_diffusion_ckpt( num_in_channels = 4 if "unet_config" in original_config.model.params: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" + "parameterization" in original_config.model["params"] + and original_config.model["params"]["parameterization"] == "v" ): if prediction_type is None: # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` @@ -1437,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt( if model_type == "FrozenOpenCLIPEmbedder": config_name = "stabilityai/stable-diffusion-2" - config_kwargs = {"subfolder": "text_encoder"} + config_kwargs: Dict[str, Union[str, int]] = {"subfolder": "text_encoder"} text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") @@ -1664,7 +1663,7 @@ def download_controlnet_from_original_ckpt( # scan model scan_result = scan_file_path(checkpoint_path) if scan_result.infected_files != 0: - raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." + raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) @@ -1685,7 +1684,7 @@ def download_controlnet_from_original_ckpt( original_config = OmegaConf.load(original_config_file) if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels if "control_stage_config" not in original_config.model.params: raise ValueError("`control_stage_config` not present in original config") @@ -1725,7 +1724,7 @@ def convert_ckpt_to_diffusers( and in addition a path-like object indicating the location of the desired diffusers model to be written. """ - pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + pipe = download_from_original_stable_diffusion_ckpt(str(checkpoint_path), **kwargs) pipe.save_pretrained( dump_path, @@ -1743,6 +1742,6 @@ def convert_controlnet_to_diffusers( and in addition a path-like object indicating the location of the desired diffusers model to be written. """ - pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) + pipe = download_controlnet_from_original_ckpt(str(checkpoint_path), **kwargs) pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py index 1cebb9dcfa..4465bd49d3 100644 --- a/invokeai/backend/model_manager/download/base.py +++ b/invokeai/backend/model_manager/download/base.py @@ -1,7 +1,5 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -Abstract base class for a multithreaded model download queue. -""" +"""Abstract base class for a multithreaded model download queue.""" from abc import ABC, abstractmethod from enum import Enum @@ -56,7 +54,7 @@ class DownloadJobBase(BaseModel): priority: int = Field(default=10, description="Queue priority; lower values are higher priority") id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder - source: str = Field(description="URL or repo_id to download") + source: Union[str, Path] = Field(description="URL or repo_id to download") destination: Path = Field(description="Destination of URL on local disk") metadata: ModelSourceMetadata = Field( description="Model metadata (source-specific)", default_factory=ModelSourceMetadata @@ -84,7 +82,8 @@ class DownloadJobBase(BaseModel): def add_event_handler(self, handler: DownloadEventHandler): """Add an event handler to the end of the handlers list.""" - self.event_handlers.append(handler) + if self.event_handlers is not None: + self.event_handlers.append(handler) def clear_event_handlers(self): """Clear all event handlers.""" @@ -226,9 +225,7 @@ class DownloadQueueBase(ABC): @abstractmethod def prune_jobs(self): - """ - Prune completed and errored queue items from the job list. - """ + """Prune completed and errored queue items from the job list.""" pass @abstractmethod diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index b6979cdb43..3ca6b130e1 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -7,19 +7,18 @@ import shutil import threading import time import traceback -from json import JSONDecodeError from pathlib import Path from queue import PriorityQueue from typing import Dict, List, Optional, Set, Tuple, Union import requests from huggingface_hub import HfApi, hf_hub_url -from pydantic import Field, ValidationError, validator +from pydantic import Field, ValidationError, parse_obj_as, validator from pydantic.networks import AnyHttpUrl from requests import HTTPError from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util import InvokeAILogger, Logger from ..storage import DuplicateModelException from . import HTTP_RE, REPO_ID_RE @@ -56,19 +55,20 @@ class DownloadJobRepoID(DownloadJobBase): """Download repo ids.""" variant: Optional[str] = Field(description="Variant, such as 'fp16', to download") + source: str = Field(description="URL or repo_id to download") @validator("source") @classmethod def _validate_source(cls, v: str) -> str: if not re.match(REPO_ID_RE, v): - raise ValidationError(f"{v} invalid repo_id") + raise ValidationError(f"{v} invalid repo_id", cls) return v class DownloadJobPath(DownloadJobBase): """Handle file paths.""" - source: Path = Field(description="Path to a file or directory to install") + source: Union[str, Path] = Field(description="Path to a file or directory to install") class DownloadQueue(DownloadQueueBase): @@ -77,8 +77,8 @@ class DownloadQueue(DownloadQueueBase): _jobs: Dict[int, DownloadJobBase] _worker_pool: Set[threading.Thread] _queue: PriorityQueue - _lock: threading.Lock - _logger: InvokeAILogger + _lock: threading.RLock + _logger: Logger _event_handlers: List[DownloadEventHandler] = Field(default_factory=list) _next_job_id: int = 0 _sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order @@ -126,6 +126,7 @@ class DownloadQueue(DownloadQueueBase): """Create a download job and return its ID.""" kwargs = dict() + cls = DownloadJobBase if Path(source).exists(): cls = DownloadJobPath elif re.match(REPO_ID_RE, str(source)): @@ -137,7 +138,7 @@ class DownloadQueue(DownloadQueueBase): raise NotImplementedError(f"Don't know what to do with this type of source: {source}") job = cls( - source=source, + source=str(source), destination=Path(destdir) / (filename or "."), access_token=access_token, event_handlers=(event_handlers or self._event_handlers), @@ -179,7 +180,7 @@ class DownloadQueue(DownloadQueueBase): def list_jobs(self) -> List[DownloadJobBase]: """List all the jobs.""" - return self._jobs.values() + return list(self._jobs.values()) def change_priority(self, job: DownloadJobBase, delta: int): """Change the priority of a job. Smaller priorities run first.""" @@ -193,9 +194,7 @@ class DownloadQueue(DownloadQueueBase): self._lock.release() def prune_jobs(self): - """ - Prune completed and errored queue items from the job list. - """ + """Prune completed and errored queue items from the job list.""" try: to_delete = set() self._lock.acquire() @@ -334,7 +333,7 @@ class DownloadQueue(DownloadQueueBase): elif isinstance(job, DownloadJobPath): self._download_path(job) else: - raise NotImplementedError(f"Don't know what to do with this job: {job}") + raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}") if job.status == DownloadJobStatus.CANCELLED: self._cleanup_cancelled_job(job) @@ -354,7 +353,7 @@ class DownloadQueue(DownloadQueueBase): model = None # a Civitai download URL - if match := re.match(CIVITAI_MODEL_DOWNLOAD, metadata_url): + if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)): version = match.group(1) resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json() metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"] @@ -364,16 +363,16 @@ class DownloadQueue(DownloadQueueBase): metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}" # a Civitai model page with the version - if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, metadata_url): + if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)): model = match.group(1) version = int(match.group(2)) # and without - elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url): + elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)): model = match.group(1) version = None if not model: - return url + return parse_obj_as(AnyHttpUrl, url) if model: resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json() @@ -419,16 +418,19 @@ class DownloadQueue(DownloadQueueBase): if job.destination.is_dir(): try: - file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1) + file_name = "" + if match := re.search('filename="(.+)"', resp.headers["Content-Disposition"]): + file_name = match.group(1) + assert file_name != "" self._validate_filename( - job.destination, file_name + job.destination.as_posix(), file_name ) # will raise a ValueError exception if file_name is suspicious except ValueError: self._logger.warning( f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead" ) file_name = os.path.basename(url) - except KeyError: + except (KeyError, AssertionError): file_name = os.path.basename(url) job.destination = job.destination / file_name dest = job.destination @@ -537,12 +539,13 @@ class DownloadQueue(DownloadQueueBase): self._update_job_status(job, DownloadJobStatus.RUNNING) return - job.subqueue = self.__class__( + subqueue = self.__class__( event_handlers=[subdownload_event], requests_session=self._requests, quiet=True, ) try: + job = DownloadJobRepoID.parse_obj(job) repo_id = job.source variant = job.variant if not job.metadata: @@ -550,12 +553,12 @@ class DownloadQueue(DownloadQueueBase): urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata) if job.destination.name != Path(repo_id).name: job.destination = job.destination / Path(repo_id).name - bytes_downloaded = dict() + bytes_downloaded: Dict[int, int] = dict() job.total_bytes = 0 for url, subdir, file, size in urls_to_download: job.total_bytes += size - job.subqueue.create_download_job( + subqueue.create_download_job( source=url, destdir=job.destination / subdir, filename=file, @@ -569,6 +572,7 @@ class DownloadQueue(DownloadQueueBase): self._update_job_status(job, DownloadJobStatus.ERROR) self._logger.error(job.error) finally: + job.subqueue = subqueue job.subqueue.join() if job.status == DownloadJobStatus.RUNNING: self._update_job_status(job, DownloadJobStatus.COMPLETED) @@ -579,7 +583,7 @@ class DownloadQueue(DownloadQueueBase): repo_id: str, metadata: ModelSourceMetadata, variant: Optional[str] = None, - ) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]: + ) -> List[Tuple[AnyHttpUrl, Path, Path, int]]: """ Given a repo_id and an optional variant, return list of URLs to download to get the model. The metadata field will be updated with model metadata from HuggingFace. @@ -602,7 +606,7 @@ class DownloadQueue(DownloadQueueBase): paths = [x for x in paths if Path(x).parent.as_posix() in submodels] paths.insert(0, "model_index.json") urls = [ - (hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()]) + (hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), Path(x.name), sizes[x.as_posix()]) for x in self._select_variants(paths, variant) ] if hasattr(model_info, "cardData"): @@ -614,7 +618,7 @@ class DownloadQueue(DownloadQueueBase): def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]: """Select the proper variant files from a list of HuggingFace repo_id paths.""" result = set() - basenames = dict() + basenames: Dict[Path, Path] = dict() for p in paths: path = Path(p) diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index 80473f8b8b..c7ebe2628b 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -57,9 +57,9 @@ class FastModelHash(object): # only tally tensor files if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): continue - path = Path(root) / file + path = (Path(root) / file).as_posix() fast_hash = cls._hash_file(path) - components.update({str(path): fast_hash}) + components.update({path: fast_hash}) # hash all the model hashes together, using alphabetic file order md5 = hashlib.md5() diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 9e0e849fa0..8d425043f4 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -54,7 +54,7 @@ import tempfile from abc import ABC, abstractmethod from pathlib import Path from shutil import move, rmtree -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union from pydantic import Field from pydantic.networks import AnyHttpUrl @@ -166,7 +166,7 @@ class ModelInstallBase(ABC): pass @abstractmethod - def install_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: + def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str: """ Probe, register and install the model in the models directory. @@ -174,7 +174,7 @@ class ModelInstallBase(ABC): the models directory handled by InvokeAI. :param model_path: Filesystem Path to the model. - :param info: Optional ModelProbeInfo object. If not provided, model will be probed. + :param overrides: Dictionary of model probe info fields that, if present, override probed values. :returns id: The string ID of the installed model. """ pass @@ -184,6 +184,7 @@ class ModelInstallBase(ABC): self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, + priority: int = 10, variant: Optional[str] = None, probe_override: Optional[Dict[str, Any]] = None, metadata: Optional[ModelSourceMetadata] = None, @@ -225,7 +226,7 @@ class ModelInstallBase(ABC): pass @abstractmethod - def wait_for_installs(self) -> Dict[str, str]: + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: """ Wait for all pending installs to complete. @@ -293,7 +294,7 @@ class ModelInstallBase(ABC): pass @abstractmethod - def sync_model_path(self, key) -> Path: + def sync_model_path(self, key) -> ModelConfigBase: """ Move model into the location indicated by its basetype, type and name. @@ -324,11 +325,11 @@ class ModelInstall(ModelInstallBase): _logger: Logger _store: ModelConfigStore _download_queue: DownloadQueueBase - _async_installs: Dict[str, str] - _installed: Set[Path] = Field(default=set) + _async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]] + _installed: Set[str] = Field(default=set) _tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads - _legacy_configs = { + _legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = { BaseModelType.StableDiffusion1: { ModelVariantType.Normal: "v1-inference.yaml", ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", @@ -357,13 +358,13 @@ class ModelInstall(ModelInstallBase): config: Optional[InvokeAIAppConfig] = None, logger: Optional[Logger] = None, download: Optional[DownloadQueueBase] = None, - event_handlers: Optional[List[DownloadEventHandler]] = None, + event_handlers: List[DownloadEventHandler] = [], ): # noqa D107 - use base class docstrings self._app_config = config or InvokeAIAppConfig.get_config() self._logger = logger or InvokeAILogger.get_logger(config=self._app_config) self._store = store or get_config_store(self._app_config.model_conf_path) self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers) - self._async_installs = dict() + self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict() self._installed = set() self._tmpdir = None @@ -403,7 +404,8 @@ class ModelInstall(ModelInstallBase): ) # add 'main' specific fields if info.model_type == ModelType.Main: - registration_data.update(variant=info.variant_type) + if info.variant_type: + registration_data.update(variant=info.variant_type) if info.format == ModelFormat.Checkpoint: try: config_file = self._legacy_configs[info.base_type][info.variant_type] @@ -416,7 +418,7 @@ class ModelInstall(ModelInstallBase): ) config_file = config_file[SchedulerPredictionType.VPrediction] registration_data.update( - config=Path(self._app_config.legacy_conf_dir, config_file).as_posix(), + config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(), ) except KeyError as exc: raise InvalidModelException( @@ -453,12 +455,12 @@ class ModelInstall(ModelInstallBase): return old_path.replace(new_path) def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo: - info: ModelProbeInfo = ModelProbe.probe(model_path) + info: ModelProbeInfo = ModelProbe.probe(Path(model_path)) if overrides: # used to override probe fields for key, value in overrides.items(): try: setattr(info, key, value) # skip validation errors - except: + except Exception: pass return info @@ -488,11 +490,11 @@ class ModelInstall(ModelInstallBase): self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, + priority: int = 10, variant: Optional[str] = None, probe_override: Optional[Dict[str, Any]] = None, metadata: Optional[ModelSourceMetadata] = None, access_token: Optional[str] = None, - priority: Optional[int] = 10, ) -> DownloadJobBase: # noqa D102 queue = self._download_queue @@ -502,7 +504,8 @@ class ModelInstall(ModelInstallBase): if inplace and Path(source).exists() else self._complete_installation_handler ) - job.probe_override = probe_override + if isinstance(job, ModelInstallJob): + job.probe_override = probe_override if metadata: job.metadata = metadata job.add_event_handler(handler) @@ -512,6 +515,7 @@ class ModelInstall(ModelInstallBase): return job def _complete_installation_handler(self, job: DownloadJobBase): + job = ModelInstallJob.parse_obj(job) # this upcast should succeed if job.status == "completed": self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.") model_id = self.install_path(job.destination, job.probe_override) @@ -537,6 +541,7 @@ class ModelInstall(ModelInstallBase): self._tmpdir = None def _complete_registration_handler(self, job: DownloadJobBase): + job = ModelInstallJob.parse_obj(job) # upcast should succeed if job.status == "completed": self._logger.info(f"{job.source}: Installing in place.") model_id = self.register_path(job.destination, job.probe_override) @@ -567,7 +572,7 @@ class ModelInstall(ModelInstallBase): models_dir = self._app_config.models_path if not old_path.is_relative_to(models_dir): - return old_path + return model new_path = models_dir / model.base_model.value / model.model_type.value / model.name self._logger.info( @@ -592,7 +597,6 @@ class ModelInstall(ModelInstallBase): # we simply probe and register/install it. The job does not actually do anything, but we # create one anyway in order to have similar behavior for local files, URLs and repo_ids. if Path(source).exists(): # a path that is already on disk - source = Path(source) destdir = source return ModelInstallPathJob(source=source, destination=Path(destdir)) @@ -600,6 +604,7 @@ class ModelInstall(ModelInstallBase): models_dir = self._app_config.models_path self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir) + cls = ModelInstallJob if re.match(REPO_ID_RE, str(source)): cls = ModelInstallRepoIDJob kwargs = dict(variant=variant) @@ -609,10 +614,15 @@ class ModelInstall(ModelInstallBase): else: raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL") return cls( - source=source, destination=Path(self._tmpdir.name), access_token=access_token, priority=priority, **kwargs + source=str(source), + destination=Path(self._tmpdir.name), + access_token=access_token, + priority=priority, + **kwargs, ) - def wait_for_installs(self) -> Dict[str, str]: # noqa D102 + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: + """Pause until all installation jobs have completed.""" self._download_queue.join() id_map = self._async_installs self._async_installs = dict() @@ -634,9 +644,10 @@ class ModelInstall(ModelInstallBase): dest_directory: Optional[Path] = None, ) -> ModelConfigBase: """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. + Convert a checkpoint file into a diffusers folder. + + It will delete the cached version ans well as the + original checkpoint file if it is in the models directory. :param key: Unique key of model. :dest_directory: Optional place to put converted file. If not specified, will be stored in the `models_dir`. @@ -668,7 +679,7 @@ class ModelInstall(ModelInstallBase): update = info.dict() update.pop("config") update["model_format"] = "diffusers" - update["path"] = converted_model.location.as_posix() + update["path"] = converted_model.location if dest_directory: new_diffusers_path = Path(dest_directory) / info.name diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py index a7560e9567..39e836559c 100644 --- a/invokeai/backend/model_manager/loader.py +++ b/invokeai/backend/model_manager/loader.py @@ -121,11 +121,7 @@ class ModelLoad(ModelLoadBase): _cache_keys: dict _models_file: Path - def __init__( - self, - config: InvokeAIAppConfig, - event_handlers: Optional[List[DownloadEventHandler]] = None, - ): + def __init__(self, config: InvokeAIAppConfig, event_handlers: List[DownloadEventHandler] = []): """ Initialize ModelLoad object. diff --git a/invokeai/backend/model_manager/lora.py b/invokeai/backend/model_manager/lora.py index bb44455c88..602d7f4638 100644 --- a/invokeai/backend/model_manager/lora.py +++ b/invokeai/backend/model_manager/lora.py @@ -12,7 +12,7 @@ from diffusers.models import UNet2DConditionModel from safetensors.torch import load_file from transformers import CLIPTextModel, CLIPTokenizer -from .models.lora import LoRAModel +from .models.lora import LoRALayerBase, LoRAModel, LoRAModelRaw """ loras = [ @@ -87,7 +87,7 @@ class ModelPatcher: def apply_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], ): with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -97,7 +97,7 @@ class ModelPatcher: def apply_sdxl_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], ): with cls.apply_lora(text_encoder, loras, "lora_te1_"): yield @@ -107,7 +107,7 @@ class ModelPatcher: def apply_sdxl_lora_text_encoder2( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], ): with cls.apply_lora(text_encoder, loras, "lora_te2_"): yield @@ -117,7 +117,7 @@ class ModelPatcher: def apply_lora( cls, model: torch.nn.Module, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, float]], prefix: str, ): original_weights = dict() @@ -337,7 +337,7 @@ class ONNXModelPatcher: def apply_lora( cls, model: IAIOnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModelRaw, torch.Tensor]], prefix: str, ): from .models.base import IAIOnnxRuntimeModel @@ -348,7 +348,7 @@ class ONNXModelPatcher: orig_weights = dict() try: - blended_loras = dict() + blended_loras: Dict[str, torch.Tensor] = dict() for lora, lora_weight in loras: for layer_key, layer in lora.layers.items(): diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 3b17f5a2ca..c0d8c90b3c 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -9,7 +9,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team import warnings from enum import Enum from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Set from diffusers import DiffusionPipeline from diffusers import logging as dlogging @@ -17,7 +17,8 @@ from diffusers import logging as dlogging import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from . import ModelConfigBase, ModelConfigStore, ModelInstall, ModelType +from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelInstall, ModelType +from .config import MainConfig class MergeInterpolationMethod(str, Enum): @@ -102,11 +103,11 @@ class ModelMerger(object): **kwargs - the default DiffusionPipeline.get_config_dict kwargs: cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map """ - model_paths = list() + model_paths: List[Path] = list() model_names = list() config = self._config store = self._store - base_models = set() + base_models: Set[BaseModelType] = set() vae = None assert ( @@ -115,6 +116,7 @@ class ModelMerger(object): for key in model_keys: info = store.get_model(key) + assert isinstance(info, MainConfig) model_names.append(info.name) assert ( info.model_format == "diffusers" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 43d51e6b96..858741d2cd 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -33,7 +33,7 @@ class ModelProbeInfo(BaseModel): base_type: BaseModelType format: ModelFormat hash: str - variant_type: Optional[ModelVariantType] = ModelVariantType("normal") + variant_type: ModelVariantType = ModelVariantType("normal") prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction") upcast_attention: Optional[bool] = False image_size: Optional[int] = None @@ -114,7 +114,7 @@ class ModelProbe(ModelProbeBase): cls, model_path: Path, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> Optional[ModelProbeInfo]: + ) -> ModelProbeInfo: """Probe model.""" try: model_type = ( @@ -129,7 +129,7 @@ class ModelProbe(ModelProbeBase): probe_class = cls.PROBES[format_type].get(model_type) if not probe_class: - return None + raise InvalidModelException(f"Unable to determine model type for {model_path}") probe = probe_class(model_path, prediction_type_helper) @@ -160,7 +160,7 @@ class ModelProbe(ModelProbeBase): else 512, ) except Exception: - raise + raise InvalidModelException(f"Unable to determine model type for {model_path}") return model_info diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 21f5d0d66f..0670ab2138 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -27,7 +27,7 @@ from typing import Callable, Optional, Set, Union from pydantic import BaseModel, Field -from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util import InvokeAILogger, Logger default_logger = InvokeAILogger.get_logger() @@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel): on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221 stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221 - logger : InvokeAILogger = Field(default=default_logger, description="InvokeAILogger instance.") # noqa E221 + logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221 # fmt: on class Config: @@ -143,7 +143,7 @@ class ModelSearch(ModelSearchBase): self.on_search_completed(self._models_found) def search(self, directory: Union[Path, str]) -> Set[Path]: - self._directory = directory + self._directory = Path(directory) self.stats = SearchStats() # zero out self.search_started() # This will initialize _models_found to empty self._walk_directory(directory) @@ -155,7 +155,7 @@ class ModelSearch(ModelSearchBase): # don't descend into directories that start with a "." # to avoid the Mac .DS_STORE issue. if str(Path(root).name).startswith("."): - self._pruned_paths.add(root) + self._pruned_paths.add(Path(root)) if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): continue diff --git a/invokeai/backend/model_manager/storage/__init__.py b/invokeai/backend/model_manager/storage/__init__.py index 4280721cd6..51675c3c4c 100644 --- a/invokeai/backend/model_manager/storage/__init__.py +++ b/invokeai/backend/model_manager/storage/__init__.py @@ -1,6 +1,4 @@ -""" -Initialization file for invokeai.backend.model_manager.storage -""" +"""Initialization file for invokeai.backend.model_manager.storage.""" import pathlib from .base import ( # noqa F401 diff --git a/invokeai/backend/model_manager/storage/base.py b/invokeai/backend/model_manager/storage/base.py index 5a85dc2530..d7585957aa 100644 --- a/invokeai/backend/model_manager/storage/base.py +++ b/invokeai/backend/model_manager/storage/base.py @@ -35,13 +35,11 @@ class ModelConfigStore(ABC): @property @abstractmethod def version(self) -> str: - """ - Return the config file/database schema version. - """ + """Return the config file/database schema version.""" pass @abstractmethod - def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None: + def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> None: """ Add a model to the database. @@ -65,7 +63,7 @@ class ModelConfigStore(ABC): pass @abstractmethod - def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: + def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: """ Update the model, returning the updated version. @@ -96,7 +94,7 @@ class ModelConfigStore(ABC): pass @abstractmethod - def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]: + def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]: """ Return models containing all of the listed tags. @@ -108,10 +106,8 @@ class ModelConfigStore(ABC): def search_by_path( self, path: Union[str, Path], - ) -> Optional[ModelConfigBase]: - """ - Return the model having the indicated path. - """ + ) -> Optional[AnyModelConfig]: + """Return the model having the indicated path.""" pass @abstractmethod @@ -120,7 +116,7 @@ class ModelConfigStore(ABC): model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, - ) -> List[ModelConfigBase]: + ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -133,8 +129,6 @@ class ModelConfigStore(ABC): """ pass - def all_models(self) -> List[ModelConfigBase]: - """ - Return all the model configs in the database. - """ + def all_models(self) -> List[AnyModelConfig]: + """Return all the model configs in the database.""" return self.search_by_name() diff --git a/invokeai/backend/model_manager/storage/migrate.py b/invokeai/backend/model_manager/storage/migrate.py index 8f9c6b2f1e..8500d4ab97 100644 --- a/invokeai/backend/model_manager/storage/migrate.py +++ b/invokeai/backend/model_manager/storage/migrate.py @@ -52,14 +52,14 @@ def migrate_models_store(config: InvokeAIAppConfig): except Exception as excp: print(str(excp)) - model_info = store.get_model(new_key) - if vae := stanza.get("vae") and isinstance(model_info, MainConfig): - model_info.vae = (app_config.models_path / vae).as_posix() - if model_config := stanza.get("config") and isinstance(model_info, MainCheckpointConfig): - model_info.config = (app_config.root_path / model_config).as_posix() - model_info.description = stanza.get("description") - store.update_model(new_key, model_info) - store.update_model(new_key, model_info) + if new_key != "": + model_info = store.get_model(new_key) + if (vae := stanza.get("vae")) and isinstance(model_info, MainConfig): + model_info.vae = (app_config.models_path / vae).as_posix() + if (model_config := stanza.get("config")) and isinstance(model_info, MainCheckpointConfig): + model_info.config = (app_config.root_path / model_config).as_posix() + model_info.description = stanza.get("description") + store.update_model(new_key, model_info) logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}") shutil.move(old_file, str(old_file) + ".orig") diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py index f692c3214e..93328979ca 100644 --- a/invokeai/backend/model_manager/storage/sql.py +++ b/invokeai/backend/model_manager/storage/sql.py @@ -398,14 +398,14 @@ class ModelConfigStoreSQL(ModelConfigStore): self._lock.release() return count > 0 - def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]: + def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]: """Return models containing all of the listed tags.""" # rather than create a hairy SQL cross-product, we intersect # tag results in a stepwise fashion at the python level. results = [] try: self._lock.acquire() - matches = set() + matches: Set[str] = set() for tag in tags: self._cursor.execute( """--sql @@ -438,7 +438,7 @@ class ModelConfigStoreSQL(ModelConfigStore): model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, - ) -> List[ModelConfigBase]: + ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -479,7 +479,5 @@ class ModelConfigStoreSQL(ModelConfigStore): return results def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]: - """ - Return the model with the indicated path, or None.. - """ + """Return the model with the indicated path, or None.""" raise NotImplementedError("search_by_path not implemented in storage.sql") diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py index cf32e30c88..5a4a42a250 100644 --- a/invokeai/backend/model_manager/storage/yaml.py +++ b/invokeai/backend/model_manager/storage/yaml.py @@ -48,7 +48,6 @@ from typing import List, Optional, Set, Union import yaml from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from omegaconf.listconfig import ListConfig from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType from .base import ( @@ -64,7 +63,7 @@ class ModelConfigStoreYAML(ModelConfigStore): """Implementation of the ModelConfigStore ABC using a YAML file.""" _filename: Path - _config: Union[DictConfig, ListConfig] + _config: DictConfig _lock: threading.RLock def __init__(self, config_file: Path): @@ -74,7 +73,9 @@ class ModelConfigStoreYAML(ModelConfigStore): self._lock = threading.RLock() if not self._filename.exists(): self._initialize_yaml() - self._config = OmegaConf.load(self._filename) + config = OmegaConf.load(self._filename) + assert isinstance(config, DictConfig) + self._config = config if str(self.version) != CONFIG_FILE_VERSION: raise ConfigFileVersionMismatchException @@ -101,7 +102,7 @@ class ModelConfigStoreYAML(ModelConfigStore): @property def version(self) -> str: """Return version of this config file/database.""" - return self._config["__metadata__"].get("version") + return self._config.__metadata__.get("version") def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None: """ diff --git a/invokeai/backend/model_manager/util.py b/invokeai/backend/model_manager/util.py index cb4a49d4b7..34ad6e92d9 100644 --- a/invokeai/backend/model_manager/util.py +++ b/invokeai/backend/model_manager/util.py @@ -149,7 +149,7 @@ def _fast_safetensors_reader(path: str): def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): if str(path).endswith(".safetensors"): try: - checkpoint = _fast_safetensors_reader(path) + checkpoint = _fast_safetensors_reader(str(path)) except Exception: # TODO: create issue for support "meta"? checkpoint = safetensors.torch.load_file(path, device="cpu") diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 5bc0d5eb80..8d763f8112 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -294,7 +294,7 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter): } def log_fmt(self, levelno: int) -> str: - return self.FORMATS.get(levelno) + return self.FORMATS[levelno] class InvokeAIPlainLogFormatter(InvokeAIFormatter): @@ -333,7 +333,7 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter): } def log_fmt(self, levelno: int) -> str: - return self.FORMATS.get(levelno) + return self.FORMATS[levelno] LOG_FORMATTERS = { diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 9e9f5e8bc8..2e6031a2c8 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -104,114 +104,6 @@ def get_obj_from_str(string, reload=False): return getattr(importlib.import_module(module, package=None), cls) -# DEAD CODE? -def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): - # create dummy dataset instance - - # run prefetching - if idx_to_fn: - res = func(data, worker_id=idx) - else: - res = func(data) - Q.put([idx, res]) - Q.put("Done") - - -# DEAD CODE? -def parallel_data_prefetch( - func: callable, - data, - n_proc, - target_data_type="ndarray", - cpu_intensive=True, - use_worker_id=False, -): - # if target_data_type not in ["ndarray", "list"]: - # raise ValueError( - # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." - # ) - if isinstance(data, np.ndarray) and target_data_type == "list": - raise ValueError("list expected but function got ndarray.") - elif isinstance(data, abc.Iterable): - if isinstance(data, dict): - logger.warning( - '"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' - ) - data = list(data.values()) - if target_data_type == "ndarray": - data = np.asarray(data) - else: - data = list(data) - else: - raise TypeError( - f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." - ) - - if cpu_intensive: - Q = mp.Queue(1000) - proc = mp.Process - else: - Q = Queue(1000) - proc = Thread - # spawn processes - if target_data_type == "ndarray": - arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))] - else: - step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc) - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)]) - ] - processes = [] - for i in range(n_proc): - p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) - processes += [p] - - # start processes - logger.info("Start prefetching...") - import time - - start = time.time() - gather_res = [[] for _ in range(n_proc)] - try: - for p in processes: - p.start() - - k = 0 - while k < n_proc: - # get result - res = Q.get() - if res == "Done": - k += 1 - else: - gather_res[res[0]] = res[1] - - except Exception as e: - logger.error("Exception: ", e) - for p in processes: - p.terminate() - - raise e - finally: - for p in processes: - p.join() - logger.info(f"Prefetching complete. [{time.time() - start} sec.]") - - if target_data_type == "ndarray": - if not isinstance(gather_res[0], np.ndarray): - return np.concatenate([np.asarray(r) for r in gather_res], axis=0) - - # order outputs - return np.concatenate(gather_res, axis=0) - elif target_data_type == "list": - out = [] - for r in gather_res: - out.extend(r) - return out - else: - return gather_res - - def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1])