diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py index 2dba376c18..a6e53c7a5c 100644 --- a/invokeai/app/api/routers/download_queue.py +++ b/invokeai/app/api/routers/download_queue.py @@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]: 400: {"description": "Bad request"}, }, ) -async def prune_downloads(): +async def prune_downloads() -> Response: """Prune completed and errored jobs.""" queue = ApiDependencies.invoker.services.download_queue queue.prune_jobs() @@ -87,7 +87,7 @@ async def get_download_job( ) async def cancel_download_job( id: int = Path(description="ID of the download job to cancel."), -): +) -> Response: """Cancel a download job using its ID.""" try: queue = ApiDependencies.invoker.services.download_queue @@ -105,7 +105,7 @@ async def cancel_download_job( 204: {"description": "Download jobs have been cancelled"}, }, ) -async def cancel_all_download_jobs(): +async def cancel_all_download_jobs() -> Response: """Cancel all download jobs.""" ApiDependencies.invoker.services.download_queue.cancel_all_jobs() return Response(status_code=204) diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py index 8d31c6f286..029c620707 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from starlette.exceptions import HTTPException from typing_extensions import Annotated @@ -37,6 +37,35 @@ from ..dependencies import ApiDependencies model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) +example_model_output = { + "path": "sd-1/main/openjourney", + "name": "openjourney", + "base": "sd-1", + "type": "main", + "format": "diffusers", + "key": "3a0e45ff858926fd4a63da630688b1e1", + "original_hash": "1c12f18fb6e403baef26fb9d720fbd2f", + "current_hash": "1c12f18fb6e403baef26fb9d720fbd2f", + "description": "sd-1 main model openjourney", + "source": "/opt/invokeai/models/sd-1/main/openjourney", + "last_modified": 1707794711, + "vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors", + "variant": "normal", + "prediction_type": "epsilon", + "repo_variant": "fp16", +} + +example_model_input = { + "path": "base/type/name", + "name": "model_name", + "base": "sd-1", + "type": "main", + "format": "diffusers", + "description": "Model description", + "vae": None, + "variant": "normal", +} + class ModelsList(BaseModel): """Return list of configs.""" @@ -88,7 +117,10 @@ async def list_model_records( "/i/{key}", operation_id="get_model_record", responses={ - 200: {"description": "Success"}, + 200: { + "description": "The model configuration was retrieved successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, }, @@ -165,18 +197,22 @@ async def search_by_metadata_tags( "/i/{key}", operation_id="update_model_record", responses={ - 200: {"description": "The model was updated successfully"}, + 200: { + "description": "The model was updated successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, 409: {"description": "There is already a model corresponding to the new name"}, }, status_code=200, - response_model=AnyModelConfig, ) async def update_model_record( key: Annotated[str, Path(description="Unique key of model")], - info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], -) -> AnyModelConfig: + info: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], +) -> Annotated[AnyModelConfig, Field(example="this is neat")]: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store @@ -225,7 +261,10 @@ async def del_model_record( "/i/", operation_id="add_model_record", responses={ - 201: {"description": "The model added successfully"}, + 201: { + "description": "The model added successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 409: {"description": "There is already a model corresponding to this path or repo_id"}, 415: {"description": "Unrecognized file/folder format"}, }, @@ -270,6 +309,7 @@ async def heuristic_import( config: Optional[Dict[str, Any]] = Body( description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", default=None, + example={"name": "modelT", "description": "antique cars"}, ), access_token: Optional[str] = None, ) -> ModelInstallJob: @@ -497,7 +537,10 @@ async def sync_models_to_config() -> Response: "/convert/{key}", operation_id="convert_model", responses={ - 200: {"description": "Model converted successfully"}, + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "Model not found"}, 409: {"description": "There is already a model registered at this location"}, @@ -571,6 +614,15 @@ async def convert_model( @model_manager_v2_router.put( "/merge", operation_id="merge", + responses={ + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_output}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, ) async def merge( keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), @@ -596,7 +648,6 @@ async def merge( interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] merge_dest_directory: Specify a directory to store the merged model in [models directory] """ - print(f"here i am, keys={keys}") logger = ApiDependencies.invoker.services.logger try: logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index dd7d9190c4..6fc232b797 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -92,10 +92,10 @@ class IPAdapterInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. - ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) + ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key) image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_models = context.services.model_records.search_by_attr( + image_encoder_models = context.services.model_manager.store.search_by_attr( model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) assert len(image_encoder_models) == 1 diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 739cd02374..a3aaf4c9e1 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -106,7 +106,7 @@ class MainModelLoaderInvocation(BaseInvocation): key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(key): + if not context.services.model_manager.store.exists(key): raise Exception(f"Unknown model {key}") return ModelLoaderOutput( @@ -175,7 +175,7 @@ class LoraLoaderInvocation(BaseInvocation): lora_key = self.lora.key - if not context.services.model_records.exists(lora_key): + if not context.services.model_manager.store.exists(lora_key): raise Exception(f"Unkown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -255,7 +255,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): lora_key = self.lora.key - if not context.services.model_records.exists(lora_key): + if not context.services.model_manager.store.exists(lora_key): raise Exception(f"Unknown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -321,7 +321,7 @@ class VaeLoaderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> VAEOutput: key = self.vae_model.key - if not context.services.model_records.exists(key): + if not context.services.model_manager.store.exists(key): raise Exception(f"Unkown vae: {key}!") return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index a304b38a95..983df6b468 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings): """Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" initconf: ClassVar[Optional[DictConfig]] = None - argparse_groups: ClassVar[Dict] = {} + argparse_groups: ClassVar[Dict[str, Any]] = {} model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) - def parse_args(self, argv: Optional[list] = sys.argv[1:]): + def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None: """Call to parse command-line arguments.""" parser = self.get_parser() opt, unknown_opts = parser.parse_known_args(argv) @@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings): return OmegaConf.to_yaml(conf) @classmethod - def add_parser_arguments(cls, parser): + def add_parser_arguments(cls, parser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] @@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings): """Return the category of a setting.""" hints = get_type_hints(cls) if command_field in hints: - return get_args(hints[command_field])[0] + result: str = get_args(hints[command_field])[0] + return result else: return "Uncategorized" @@ -158,7 +159,7 @@ class InvokeAISettings(BaseSettings): ] @classmethod - def add_field_argument(cls, command_parser, name: str, field, default_override=None): + def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None: """Add the argparse arguments for a setting parser.""" field_type = get_type_hints(cls).get(name) default = ( diff --git a/invokeai/app/services/config/config_common.py b/invokeai/app/services/config/config_common.py index d11bcabcf9..27a0f859c2 100644 --- a/invokeai/app/services/config/config_common.py +++ b/invokeai/app/services/config/config_common.py @@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser): It also supports reading defaults from an init file. """ - def print_help(self, file=None): + def print_help(self, file=None) -> None: text = self.format_help() pydoc.pager(text) diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index f740c50087..7008f8ed74 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,12 +8,12 @@ import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError -from tqdm import tqdm +from tqdm import tqdm, std from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp @@ -49,12 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase): :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. """ - self._jobs = {} + self._jobs: Dict[int, DownloadJob] = {} self._next_job_id = 0 - self._queue = PriorityQueue() + self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() self._job_completed_event = threading.Event() - self._worker_pool = set() + self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") self._event_bus = event_bus @@ -424,7 +424,7 @@ class DownloadQueueService(DownloadQueueServiceBase): class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" - _bars: Dict[int, tqdm] # the tqdm object + _bars: Dict[int, tqdm] # type: ignore _last: Dict[int, int] # last bytes downloaded def __init__(self) -> None: # noqa D107 diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index 910b05d8dd..da431929db 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -5,7 +5,7 @@ import uuid import numpy as np -def get_timestamp(): +def get_timestamp() -> int: return int(datetime.datetime.now(datetime.timezone.utc).timestamp()) @@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: SEED_MAX = np.iinfo(np.uint32).max -def get_random_seed(): +def get_random_seed() -> int: rng = np.random.default_rng(seed=None) return int(rng.integers(0, SEED_MAX)) -def uuid_string(): +def uuid_string() -> str: res = uuid.uuid4() return str(res) -def is_optional(value: typing.Any): +def is_optional(value: typing.Any) -> bool: """Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None].""" return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 5f392ada75..7649dee762 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import ( AnyModel, AnyModelConfig, BaseModelType, + ModelConfigBase, ModelFormat, ModelType, SubModelType, @@ -70,7 +71,7 @@ class ModelLoaderBase(ABC): pass @abstractmethod - def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its confguration. @@ -122,7 +123,7 @@ class AnyModelLoader: """Return the convert cache associated used by the loaders.""" return self._convert_cache - def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its configuration. @@ -144,8 +145,8 @@ class AnyModelLoader: @classmethod def get_implementation( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]: + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: """Get subclass of ModelLoaderBase registered to handle base and type.""" # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) @@ -161,8 +162,8 @@ class AnyModelLoader: @classmethod def _handle_subtype_overrides( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[AnyModelConfig, Optional[SubModelType]]: + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: model_path = Path(config.vae) config_class = ( diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 98d6f34cea..786396062c 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase -from .model_locker import ModelLocker, ModelLockerBase +from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase +from .model_locker import ModelLocker if choose_torch_device() == torch.device("mps"): from torch import mps diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py index d628ab5c17..5d75493b92 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -20,7 +20,7 @@ from requests.sessions import Session from invokeai.backend.model_manager import ModelRepoVariant -from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator +from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata class ModelMetadataFetchBase(ABC): @@ -62,5 +62,5 @@ class ModelMetadataFetchBase(ABC): @classmethod def from_json(cls, json: str) -> AnyModelRepoMetadata: """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" - metadata = AnyModelRepoMetadataValidator.validate_json(json) + metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore return metadata diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index e7d21c578f..2c2066d7c5 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -166,7 +166,7 @@ class ModelProbe(object): fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash - if format_type == ModelFormat.Diffusers: + if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() # additional fields needed for main and controlnet models diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index f7e1e1bed7..0ead22b743 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - models_found: Set[Path] = Field(default=None) - scanned_dirs: Set[Path] = Field(default=None) - pruned_paths: Set[Path] = Field(default=None) + models_found: Optional[Set[Path]] = Field(default=None) + scanned_dirs: Optional[Set[Path]] = Field(default=None) + pruned_paths: Optional[Set[Path]] = Field(default=None) def search_started(self) -> None: self.models_found = set()