fix a number of typechecking errors

This commit is contained in:
Lincoln Stein 2024-02-13 00:26:49 -05:00
parent 433eb73d8e
commit bd802d1e7a
13 changed files with 101 additions and 48 deletions

View File

@@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
         400: {"description": "Bad request"},
     },
 )
-async def prune_downloads():
+async def prune_downloads() -> Response:
     """Prune completed and errored jobs."""
     queue = ApiDependencies.invoker.services.download_queue
     queue.prune_jobs()
@@ -87,7 +87,7 @@ async def get_download_job(
 )
 async def cancel_download_job(
     id: int = Path(description="ID of the download job to cancel."),
-):
+) -> Response:
     """Cancel a download job using its ID."""
     try:
         queue = ApiDependencies.invoker.services.download_queue
@@ -105,7 +105,7 @@ async def cancel_download_job(
         204: {"description": "Download jobs have been cancelled"},
     },
 )
-async def cancel_all_download_jobs():
+async def cancel_all_download_jobs() -> Response:
     """Cancel all download jobs."""
     ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
     return Response(status_code=204)
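Annotating these endpoints with -> Response lets a strict type checker verify that every code path actually returns a Response, instead of inferring Any from the bare signature. A minimal sketch of the pattern, using a stand-in route rather than InvokeAI's actual router:

from fastapi import FastAPI, Response

app = FastAPI()


@app.patch("/downloads", responses={204: {"description": "Jobs pruned"}})
async def prune_downloads() -> Response:
    # With the annotation, mypy flags a missing or wrong-typed return;
    # without it, the endpoint's return type is inferred as Any.
    return Response(status_code=204)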

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set
 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 from starlette.exceptions import HTTPException
 from typing_extensions import Annotated
@@ -37,6 +37,35 @@ from ..dependencies import ApiDependencies
 model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
 
+example_model_output = {
+    "path": "sd-1/main/openjourney",
+    "name": "openjourney",
+    "base": "sd-1",
+    "type": "main",
+    "format": "diffusers",
+    "key": "3a0e45ff858926fd4a63da630688b1e1",
+    "original_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
+    "current_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
+    "description": "sd-1 main model openjourney",
+    "source": "/opt/invokeai/models/sd-1/main/openjourney",
+    "last_modified": 1707794711,
+    "vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors",
+    "variant": "normal",
+    "prediction_type": "epsilon",
+    "repo_variant": "fp16",
+}
+
+example_model_input = {
+    "path": "base/type/name",
+    "name": "model_name",
+    "base": "sd-1",
+    "type": "main",
+    "format": "diffusers",
+    "description": "Model description",
+    "vae": None,
+    "variant": "normal",
+}
+
 
 class ModelsList(BaseModel):
     """Return list of configs."""
@@ -88,7 +117,10 @@ async def list_model_records(
     "/i/{key}",
     operation_id="get_model_record",
     responses={
-        200: {"description": "Success"},
+        200: {
+            "description": "The model configuration was retrieved successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
         400: {"description": "Bad request"},
         404: {"description": "The model could not be found"},
     },
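Wiring example_model_output into the responses mapping is how FastAPI surfaces a sample payload in the generated OpenAPI schema and on the /docs page. A self-contained sketch of the mechanism, with a toy payload standing in for the real one:

from fastapi import FastAPI

app = FastAPI()
example_payload = {"name": "openjourney", "base": "sd-1"}


@app.get(
    "/models/{key}",
    responses={
        200: {
            "description": "Success",
            "content": {"application/json": {"example": example_payload}},
        },
    },
)
async def get_model(key: str) -> dict:
    # The example affects documentation only; responses are not
    # validated against it.
    return example_payload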
@@ -165,18 +197,22 @@ async def search_by_metadata_tags(
     "/i/{key}",
     operation_id="update_model_record",
     responses={
-        200: {"description": "The model was updated successfully"},
+        200: {
+            "description": "The model was updated successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
         400: {"description": "Bad request"},
         404: {"description": "The model could not be found"},
         409: {"description": "There is already a model corresponding to the new name"},
     },
     status_code=200,
-    response_model=AnyModelConfig,
 )
 async def update_model_record(
     key: Annotated[str, Path(description="Unique key of model")],
-    info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
-) -> AnyModelConfig:
+    info: Annotated[
+        AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
+    ],
+) -> Annotated[AnyModelConfig, Field(example="this is neat")]:
     """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_manager.store
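The info parameter above is a Pydantic discriminated union: discriminator="type" makes validation dispatch on the type field instead of trying every member of AnyModelConfig in turn. A toy sketch of the idea (these config classes are illustrative, not the real hierarchy):

from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MainConfig(BaseModel):
    type: Literal["main"]
    path: str


class VaeConfig(BaseModel):
    type: Literal["vae"]
    path: str


AnyConfig = Annotated[Union[MainConfig, VaeConfig], Field(discriminator="type")]


class Wrapper(BaseModel):
    config: AnyConfig


# The "type" value selects the union branch directly:
w = Wrapper.model_validate({"config": {"type": "vae", "path": "x.safetensors"}})
assert isinstance(w.config, VaeConfig)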
@@ -225,7 +261,10 @@ async def del_model_record(
     "/i/",
     operation_id="add_model_record",
     responses={
-        201: {"description": "The model added successfully"},
+        201: {
+            "description": "The model added successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
         409: {"description": "There is already a model corresponding to this path or repo_id"},
         415: {"description": "Unrecognized file/folder format"},
     },
@@ -270,6 +309,7 @@ async def heuristic_import(
     config: Optional[Dict[str, Any]] = Body(
         description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
         default=None,
+        example={"name": "modelT", "description": "antique cars"},
     ),
     access_token: Optional[str] = None,
 ) -> ModelInstallJob:
@@ -497,7 +537,10 @@ async def sync_models_to_config() -> Response:
     "/convert/{key}",
     operation_id="convert_model",
     responses={
-        200: {"description": "Model converted successfully"},
+        200: {
+            "description": "Model converted successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
         400: {"description": "Bad request"},
         404: {"description": "Model not found"},
         409: {"description": "There is already a model registered at this location"},
@@ -571,6 +614,15 @@ async def convert_model(
 @model_manager_v2_router.put(
     "/merge",
     operation_id="merge",
+    responses={
+        200: {
+            "description": "Model converted successfully",
+            "content": {"application/json": {"example": example_model_output}},
+        },
+        400: {"description": "Bad request"},
+        404: {"description": "Model not found"},
+        409: {"description": "There is already a model registered at this location"},
+    },
 )
 async def merge(
     keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
@@ -596,7 +648,6 @@ async def merge(
         interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
         merge_dest_directory: Specify a directory to store the merged model in [models directory]
     """
-    print(f"here i am, keys={keys}")
     logger = ApiDependencies.invoker.services.logger
     try:
         logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")

View File

@@ -92,10 +92,10 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
-        ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key)
+        ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
         image_encoder_model_id = ip_adapter_info.image_encoder_model_id
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
-        image_encoder_models = context.services.model_records.search_by_attr(
+        image_encoder_models = context.services.model_manager.store.search_by_attr(
             model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
         )
         assert len(image_encoder_models) == 1

View File

@@ -106,7 +106,7 @@ class MainModelLoaderInvocation(BaseInvocation):
         key = self.model.key
 
         # TODO: not found exceptions
-        if not context.services.model_records.exists(key):
+        if not context.services.model_manager.store.exists(key):
             raise Exception(f"Unknown model {key}")
 
         return ModelLoaderOutput(
@@ -175,7 +175,7 @@ class LoraLoaderInvocation(BaseInvocation):
         lora_key = self.lora.key
 
-        if not context.services.model_records.exists(lora_key):
+        if not context.services.model_manager.store.exists(lora_key):
             raise Exception(f"Unkown lora: {lora_key}!")
 
         if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -255,7 +255,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
         lora_key = self.lora.key
 
-        if not context.services.model_records.exists(lora_key):
+        if not context.services.model_manager.store.exists(lora_key):
             raise Exception(f"Unknown lora: {lora_key}!")
 
         if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -321,7 +321,7 @@ class VaeLoaderInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> VAEOutput:
         key = self.vae_model.key
-        if not context.services.model_records.exists(key):
+        if not context.services.model_manager.store.exists(key):
            raise Exception(f"Unkown vae: {key}!")
 
         return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
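These four loader hunks, like the IP-Adapter change above, are the same mechanical substitution: the old context.services.model_records accessor is replaced by context.services.model_manager.store, which appears to expose the same record API (exists, get_model, search_by_attr) through the consolidated model manager service.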

View File

@@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings):
     """Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
 
     initconf: ClassVar[Optional[DictConfig]] = None
-    argparse_groups: ClassVar[Dict] = {}
+    argparse_groups: ClassVar[Dict[str, Any]] = {}
 
     model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
 
-    def parse_args(self, argv: Optional[list] = sys.argv[1:]):
+    def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
         """Call to parse command-line arguments."""
         parser = self.get_parser()
         opt, unknown_opts = parser.parse_known_args(argv)
@@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
         return OmegaConf.to_yaml(conf)
 
     @classmethod
-    def add_parser_arguments(cls, parser):
+    def add_parser_arguments(cls, parser) -> None:
         """Dynamically create arguments for a settings parser."""
         if "type" in get_type_hints(cls):
             settings_stanza = get_args(get_type_hints(cls)["type"])[0]
@@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings):
         """Return the category of a setting."""
         hints = get_type_hints(cls)
         if command_field in hints:
-            return get_args(hints[command_field])[0]
+            result: str = get_args(hints[command_field])[0]
+            return result
         else:
             return "Uncategorized"
@@ -158,7 +159,7 @@ class InvokeAISettings(BaseSettings):
         ]
 
     @classmethod
-    def add_field_argument(cls, command_parser, name: str, field, default_override=None):
+    def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
         """Add the argparse arguments for a setting parser."""
         field_type = get_type_hints(cls).get(name)
         default = (
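The theme in this file is replacing bare generics (Dict, list) with parameterized ones and adding -> None to procedures, both of which strict mypy settings such as --disallow-any-generics and --disallow-untyped-defs require. A small sketch of both fixes with a toy class (note it keeps the commit's sys.argv default for fidelity, not as a recommendation):

import sys
from typing import Any, ClassVar, Dict, List, Optional


class Settings:
    # Bare Dict is implicitly Dict[Any, Any]; spelling out the parameters
    # satisfies --disallow-any-generics.
    argparse_groups: ClassVar[Dict[str, Any]] = {}

    def parse_args(self, argv: Optional[List[str]] = None) -> None:
        # The commit keeps `= sys.argv[1:]` as the default, which is
        # captured once at import time; None plus a runtime fallback is
        # the more defensive idiom.
        argv = argv if argv is not None else sys.argv[1:]
        self.argparse_groups["argv"] = argv


Settings().parse_args(["--foo"])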

View File

@@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
     It also supports reading defaults from an init file.
     """
 
-    def print_help(self, file=None):
+    def print_help(self, file=None) -> None:
         text = self.format_help()
         pydoc.pager(text)

View File

@@ -8,12 +8,12 @@ import time
 import traceback
 from pathlib import Path
 from queue import Empty, PriorityQueue
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set
 
 import requests
 from pydantic.networks import AnyHttpUrl
 from requests import HTTPError
-from tqdm import tqdm
+from tqdm import tqdm, std
 
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.util.misc import get_iso_timestamp
@@ -49,12 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
         :param max_parallel_dl: Number of simultaneous downloads allowed [5].
         :param requests_session: Optional requests.sessions.Session object, for unit tests.
         """
-        self._jobs = {}
+        self._jobs: Dict[int, DownloadJob] = {}
         self._next_job_id = 0
-        self._queue = PriorityQueue()
+        self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
         self._stop_event = threading.Event()
         self._job_completed_event = threading.Event()
-        self._worker_pool = set()
+        self._worker_pool: Set[threading.Thread] = set()
         self._lock = threading.Lock()
         self._logger = InvokeAILogger.get_logger("DownloadQueueService")
         self._event_bus = event_bus
@@ -424,7 +424,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
 class TqdmProgress(object):
     """TQDM-based progress bar object to use in on_progress handlers."""
 
-    _bars: Dict[int, tqdm]  # the tqdm object
+    _bars: Dict[int, tqdm]  # type: ignore
     _last: Dict[int, int]  # last bytes downloaded
 
     def __init__(self) -> None:  # noqa D107
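Annotating the job map, queue, and worker pool gives mypy element types to propagate through get()/put() and set operations; the tqdm mapping instead gains a type: ignore, presumably because the bare tqdm class does not pass as a type parameter under the project's mypy settings. A sketch of the container annotations with a stand-in job type:

import threading
from dataclasses import dataclass, field
from queue import PriorityQueue
from typing import Dict, Set


@dataclass(order=True)
class Job:
    priority: int
    id: int = field(compare=False)  # ordering considers priority only


jobs: Dict[int, Job] = {}
# Subscripting queue.PriorityQueue at runtime needs Python 3.9+;
# quote the annotation on older interpreters.
queue: PriorityQueue[Job] = PriorityQueue()
workers: Set[threading.Thread] = set()

queue.put(Job(priority=1, id=0))
assert queue.get().id == 0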

View File

@@ -5,7 +5,7 @@ import uuid
 import numpy as np
 
 
-def get_timestamp():
+def get_timestamp() -> int:
     return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
@@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
 SEED_MAX = np.iinfo(np.uint32).max
 
 
-def get_random_seed():
+def get_random_seed() -> int:
     rng = np.random.default_rng(seed=None)
     return int(rng.integers(0, SEED_MAX))
 
 
-def uuid_string():
+def uuid_string() -> str:
     res = uuid.uuid4()
     return str(res)
 
 
-def is_optional(value: typing.Any):
+def is_optional(value: typing.Any) -> bool:
     """Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
     return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
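The is_optional helper works because Optional[X] is literally Union[X, None] at runtime, so the check reduces to inspecting the union's origin and arguments. Restating it with usage:

import typing


def is_optional(value: typing.Any) -> bool:
    return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)


assert is_optional(typing.Optional[int])        # Union[int, None]
assert is_optional(typing.Union[str, None])     # the same type, spelled out
assert not is_optional(typing.Union[str, int])  # a union, but not optional
assert not is_optional(int)                     # not a union at all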

View File

@@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
     AnyModel,
     AnyModelConfig,
     BaseModelType,
+    ModelConfigBase,
     ModelFormat,
     ModelType,
     SubModelType,
@@ -70,7 +71,7 @@ class ModelLoaderBase(ABC):
         pass
 
     @abstractmethod
-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Return a model given its confguration.
@@ -122,7 +123,7 @@ class AnyModelLoader:
         """Return the convert cache associated used by the loaders."""
         return self._convert_cache
 
-    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+    def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Return a model given its configuration.
@@ -144,8 +145,8 @@ class AnyModelLoader:
     @classmethod
     def get_implementation(
-        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
-    ) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
+        cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
+    ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
         """Get subclass of ModelLoaderBase registered to handle base and type."""
         # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
         conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
@@ -161,8 +162,8 @@ class AnyModelLoader:
     @classmethod
     def _handle_subtype_overrides(
-        cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
-    ) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
+        cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
+    ) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
         if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
             model_path = Path(config.vae)
             config_class = (
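Widening these signatures from AnyModelConfig (a union of the concrete config classes, per the Body(discriminator="type") usage earlier in this commit) to ModelConfigBase (their common base) has a practical typing consequence: a base-class parameter accepts every present and future subclass, while a union parameter must list each member explicitly. A stand-in sketch:

from typing import Union


class ModelConfigBase:
    pass


class MainConfig(ModelConfigBase):
    pass


class VaeConfig(ModelConfigBase):
    pass


AnyModelConfig = Union[MainConfig, VaeConfig]


def load_model(config: ModelConfigBase) -> str:
    # Any ModelConfigBase subclass type-checks here; with the union type,
    # a new config class would be rejected until added to AnyModelConfig.
    return type(config).__name__


assert load_model(VaeConfig()) == "VaeConfig"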

View File

@@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot,
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.logging import InvokeAILogger
 
-from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
-from .model_locker import ModelLocker, ModelLockerBase
+from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
+from .model_locker import ModelLocker
 
 if choose_torch_device() == torch.device("mps"):
     from torch import mps
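Importing ModelLockerBase from .model_cache_base, alongside the other cache abstractions, rather than through .model_locker keeps the dependency between the two modules one-way; this is the usual fix when a shared abstract base would otherwise invite a circular import between a cache module and the locker module that uses it.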

View File

@@ -20,7 +20,7 @@ from requests.sessions import Session
 from invokeai.backend.model_manager import ModelRepoVariant
 
-from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
+from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata
 
 
 class ModelMetadataFetchBase(ABC):
@@ -62,5 +62,5 @@ class ModelMetadataFetchBase(ABC):
     @classmethod
     def from_json(cls, json: str) -> AnyModelRepoMetadata:
         """Given the JSON representation of the metadata, return the corresponding Pydantic object."""
-        metadata = AnyModelRepoMetadataValidator.validate_json(json)
+        metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json)  # type: ignore
         return metadata
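The annotation plus type: ignore pins the validator's result to a concrete declared type. If AnyModelRepoMetadataValidator is a pydantic TypeAdapter over the metadata union (an assumption here), the pattern looks like this with stand-in classes:

from typing import Union

from pydantic import BaseModel, TypeAdapter


class BaseMetadata(BaseModel):
    name: str


class HuggingFaceMetadata(BaseMetadata):
    repo_id: str


class CivitaiMetadata(BaseMetadata):
    model_id: int


AnyModelRepoMetadata = Union[HuggingFaceMetadata, CivitaiMetadata]
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)

# validate_json returns the union; annotating the binding as the shared
# base class documents the contract the caller relies on.
metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json('{"name": "m", "repo_id": "org/m"}')
assert isinstance(metadata, HuggingFaceMetadata)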

View File

@@ -166,7 +166,7 @@ class ModelProbe(object):
         fields["original_hash"] = fields.get("original_hash") or hash
         fields["current_hash"] = fields.get("current_hash") or hash
 
-        if format_type == ModelFormat.Diffusers:
+        if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
             fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
 
         # additional fields needed for main and controlnet models
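The added hasattr check guards the get_repo_variant call for probe classes that do not implement it, and recent mypy versions also narrow on hasattr. A stand-in sketch of the guard:

from typing import Optional


class ProbeBase:
    pass


class FolderProbe(ProbeBase):
    def get_repo_variant(self) -> str:
        return "fp16"


def repo_variant_of(probe: ProbeBase) -> Optional[str]:
    # An unconditional probe.get_repo_variant() fails type checking (and
    # raises AttributeError at runtime for probes without the method).
    if hasattr(probe, "get_repo_variant"):
        return probe.get_repo_variant()
    return None


assert repo_variant_of(FolderProbe()) == "fp16"
assert repo_variant_of(ProbeBase()) is None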

View File

@@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
     # returns all models that have 'anime' in the path
     """
 
-    models_found: Set[Path] = Field(default=None)
-    scanned_dirs: Set[Path] = Field(default=None)
-    pruned_paths: Set[Path] = Field(default=None)
+    models_found: Optional[Set[Path]] = Field(default=None)
+    scanned_dirs: Optional[Set[Path]] = Field(default=None)
+    pruned_paths: Optional[Set[Path]] = Field(default=None)
 
     def search_started(self) -> None:
         self.models_found = set()
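These fields previously claimed Set[Path] while defaulting to None, a contradiction the type checker flags: unlike Pydantic v1, Pydantic v2 no longer widens an annotation to Optional just because the default is None. The fix spells the optionality out. A toy sketch of the corrected pattern:

from pathlib import Path
from typing import Optional, Set

from pydantic import BaseModel, Field


class Search(BaseModel):
    # Set[Path] = Field(default=None) mismatches: None is not a Set[Path].
    models_found: Optional[Set[Path]] = Field(default=None)

    def search_started(self) -> None:
        self.models_found = set()


s = Search()
assert s.models_found is None
s.search_started()
assert s.models_found == set()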