fix a number of typechecking errors

This commit is contained in:
Lincoln Stein 2024-02-13 00:26:49 -05:00 committed by psychedelicious
parent 0845a0ed84
commit 631f6cae19
13 changed files with 101 additions and 48 deletions

View File

@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
400: {"description": "Bad request"},
},
)
async def prune_downloads():
async def prune_downloads() -> Response:
"""Prune completed and errored jobs."""
queue = ApiDependencies.invoker.services.download_queue
queue.prune_jobs()
@ -87,7 +87,7 @@ async def get_download_job(
)
async def cancel_download_job(
id: int = Path(description="ID of the download job to cancel."),
):
) -> Response:
"""Cancel a download job using its ID."""
try:
queue = ApiDependencies.invoker.services.download_queue
@ -105,7 +105,7 @@ async def cancel_download_job(
204: {"description": "Download jobs have been cancelled"},
},
)
async def cancel_all_download_jobs():
async def cancel_all_download_jobs() -> Response:
"""Cancel all download jobs."""
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
return Response(status_code=204)

View File

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -37,6 +37,35 @@ from ..dependencies import ApiDependencies
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
example_model_output = {
"path": "sd-1/main/openjourney",
"name": "openjourney",
"base": "sd-1",
"type": "main",
"format": "diffusers",
"key": "3a0e45ff858926fd4a63da630688b1e1",
"original_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
"current_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
"description": "sd-1 main model openjourney",
"source": "/opt/invokeai/models/sd-1/main/openjourney",
"last_modified": 1707794711,
"vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors",
"variant": "normal",
"prediction_type": "epsilon",
"repo_variant": "fp16",
}
example_model_input = {
"path": "base/type/name",
"name": "model_name",
"base": "sd-1",
"type": "main",
"format": "diffusers",
"description": "Model description",
"vae": None,
"variant": "normal",
}
class ModelsList(BaseModel):
"""Return list of configs."""
@ -88,7 +117,10 @@ async def list_model_records(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {"description": "Success"},
200: {
"description": "The model configuration was retrieved successfully",
"content": {"application/json": {"example": example_model_output}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
@ -165,18 +197,22 @@ async def search_by_metadata_tags(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {"description": "The model was updated successfully"},
200: {
"description": "The model was updated successfully",
"content": {"application/json": {"example": example_model_output}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=AnyModelConfig,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> Annotated[AnyModelConfig, Field(example="this is neat")]:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
@ -225,7 +261,10 @@ async def del_model_record(
"/i/",
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
201: {
"description": "The model added successfully",
"content": {"application/json": {"example": example_model_output}},
},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
@ -270,6 +309,7 @@ async def heuristic_import(
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
example={"name": "modelT", "description": "antique cars"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
@ -497,7 +537,10 @@ async def sync_models_to_config() -> Response:
"/convert/{key}",
operation_id="convert_model",
responses={
200: {"description": "Model converted successfully"},
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_output}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
@ -571,6 +614,15 @@ async def convert_model(
@model_manager_v2_router.put(
"/merge",
operation_id="merge",
responses={
200: {
"description": "Model converted successfully",
"content": {"application/json": {"example": example_model_output}},
},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
409: {"description": "There is already a model registered at this location"},
},
)
async def merge(
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
@ -596,7 +648,6 @@ async def merge(
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
merge_dest_directory: Specify a directory to store the merged model in [models directory]
"""
print(f"here i am, keys={keys}")
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")

View File

@ -90,10 +90,10 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key)
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.services.model_records.search_by_attr(
image_encoder_models = context.services.model_manager.store.search_by_attr(
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1

View File

@ -103,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
key = self.model.key
# TODO: not found exceptions
if not context.services.model_records.exists(key):
if not context.services.model_manager.store.exists(key):
raise Exception(f"Unknown model {key}")
return ModelLoaderOutput(
@ -172,7 +172,7 @@ class LoraLoaderInvocation(BaseInvocation):
lora_key = self.lora.key
if not context.services.model_records.exists(lora_key):
if not context.services.model_manager.store.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@ -252,7 +252,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
lora_key = self.lora.key
if not context.services.model_records.exists(lora_key):
if not context.services.model_manager.store.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> VAEOutput:
key = self.vae_model.key
if not context.services.model_records.exists(key):
if not context.services.model_manager.store.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))

View File

@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings):
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
initconf: ClassVar[Optional[DictConfig]] = None
argparse_groups: ClassVar[Dict] = {}
argparse_groups: ClassVar[Dict[str, Any]] = {}
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
"""Call to parse command-line arguments."""
parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv)
@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
def add_parser_arguments(cls, parser) -> None:
"""Dynamically create arguments for a settings parser."""
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings):
"""Return the category of a setting."""
hints = get_type_hints(cls)
if command_field in hints:
return get_args(hints[command_field])[0]
result: str = get_args(hints[command_field])[0]
return result
else:
return "Uncategorized"
@ -158,7 +159,7 @@ class InvokeAISettings(BaseSettings):
]
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
"""Add the argparse arguments for a setting parser."""
field_type = get_type_hints(cls).get(name)
default = (

View File

@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
It also supports reading defaults from an init file.
"""
def print_help(self, file=None):
def print_help(self, file=None) -> None:
text = self.format_help()
pydoc.pager(text)

View File

@ -8,12 +8,12 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from tqdm import tqdm, std
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
@ -49,12 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._jobs = {}
self._jobs: Dict[int, DownloadJob] = {}
self._next_job_id = 0
self._queue = PriorityQueue()
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event()
self._job_completed_event = threading.Event()
self._worker_pool = set()
self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
self._event_bus = event_bus
@ -424,7 +424,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
class TqdmProgress(object):
"""TQDM-based progress bar object to use in on_progress handlers."""
_bars: Dict[int, tqdm] # the tqdm object
_bars: Dict[int, tqdm] # type: ignore
_last: Dict[int, int] # last bytes downloaded
def __init__(self) -> None: # noqa D107

View File

@ -5,7 +5,7 @@ import uuid
import numpy as np
def get_timestamp():
def get_timestamp() -> int:
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
SEED_MAX = np.iinfo(np.uint32).max
def get_random_seed():
def get_random_seed() -> int:
rng = np.random.default_rng(seed=None)
return int(rng.integers(0, SEED_MAX))
def uuid_string():
def uuid_string() -> str:
res = uuid.uuid4()
return str(res)
def is_optional(value: typing.Any):
def is_optional(value: typing.Any) -> bool:
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)

View File

@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
@ -70,7 +71,7 @@ class ModelLoaderBase(ABC):
pass
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its confguration.
@ -122,7 +123,7 @@ class AnyModelLoader:
"""Return the convert cache associated used by the loaders."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its configuration.
@ -144,8 +145,8 @@ class AnyModelLoader:
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
@ -161,8 +162,8 @@ class AnyModelLoader:
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (

View File

@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot,
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
from .model_locker import ModelLocker, ModelLockerBase
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps

View File

@ -20,7 +20,7 @@ from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata
class ModelMetadataFetchBase(ABC):
@ -62,5 +62,5 @@ class ModelMetadataFetchBase(ABC):
@classmethod
def from_json(cls, json: str) -> AnyModelRepoMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore
return metadata

View File

@ -166,7 +166,7 @@ class ModelProbe(object):
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
if format_type == ModelFormat.Diffusers:
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models

View File

@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
# returns all models that have 'anime' in the path
"""
models_found: Set[Path] = Field(default=None)
scanned_dirs: Set[Path] = Field(default=None)
pruned_paths: Set[Path] = Field(default=None)
models_found: Optional[Set[Path]] = Field(default=None)
scanned_dirs: Optional[Set[Path]] = Field(default=None)
pruned_paths: Optional[Set[Path]] = Field(default=None)
def search_started(self) -> None:
self.models_found = set()