mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix a number of typechecking errors
This commit is contained in:
parent
0845a0ed84
commit
631f6cae19
@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_downloads():
|
||||
async def prune_downloads() -> Response:
|
||||
"""Prune completed and errored jobs."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
queue.prune_jobs()
|
||||
@ -87,7 +87,7 @@ async def get_download_job(
|
||||
)
|
||||
async def cancel_download_job(
|
||||
id: int = Path(description="ID of the download job to cancel."),
|
||||
):
|
||||
) -> Response:
|
||||
"""Cancel a download job using its ID."""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
@ -105,7 +105,7 @@ async def cancel_download_job(
|
||||
204: {"description": "Download jobs have been cancelled"},
|
||||
},
|
||||
)
|
||||
async def cancel_all_download_jobs():
|
||||
async def cancel_all_download_jobs() -> Response:
|
||||
"""Cancel all download jobs."""
|
||||
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
|
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
@ -37,6 +37,35 @@ from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
|
||||
|
||||
example_model_output = {
|
||||
"path": "sd-1/main/openjourney",
|
||||
"name": "openjourney",
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "diffusers",
|
||||
"key": "3a0e45ff858926fd4a63da630688b1e1",
|
||||
"original_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
|
||||
"current_hash": "1c12f18fb6e403baef26fb9d720fbd2f",
|
||||
"description": "sd-1 main model openjourney",
|
||||
"source": "/opt/invokeai/models/sd-1/main/openjourney",
|
||||
"last_modified": 1707794711,
|
||||
"vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors",
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
"path": "base/type/name",
|
||||
"name": "model_name",
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "diffusers",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
@ -88,7 +117,10 @@ async def list_model_records(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
200: {
|
||||
"description": "The model configuration was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_output}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
@ -165,18 +197,22 @@ async def search_by_metadata_tags(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
200: {
|
||||
"description": "The model was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_output}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> Annotated[AnyModelConfig, Field(example="this is neat")]:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
@ -225,7 +261,10 @@ async def del_model_record(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_output}},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
@ -270,6 +309,7 @@ async def heuristic_import(
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
example={"name": "modelT", "description": "antique cars"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
@ -497,7 +537,10 @@ async def sync_models_to_config() -> Response:
|
||||
"/convert/{key}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_output}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
@ -571,6 +614,15 @@ async def convert_model(
|
||||
@model_manager_v2_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_output}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
@ -596,7 +648,6 @@ async def merge(
|
||||
interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory: Specify a directory to store the merged model in [models directory]
|
||||
"""
|
||||
print(f"here i am, keys={keys}")
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
|
@ -90,10 +90,10 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key)
|
||||
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_models = context.services.model_records.search_by_attr(
|
||||
image_encoder_models = context.services.model_manager.store.search_by_attr(
|
||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
|
@ -103,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_records.exists(key):
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
return ModelLoaderOutput(
|
||||
@ -172,7 +172,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_records.exists(lora_key):
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
@ -252,7 +252,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_records.exists(lora_key):
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
key = self.vae_model.key
|
||||
|
||||
if not context.services.model_records.exists(key):
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
|
@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings):
|
||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
argparse_groups: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||
|
||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
|
||||
"""Call to parse command-line arguments."""
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
def add_parser_arguments(cls, parser) -> None:
|
||||
"""Dynamically create arguments for a settings parser."""
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings):
|
||||
"""Return the category of a setting."""
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
result: str = get_args(hints[command_field])[0]
|
||||
return result
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@ -158,7 +159,7 @@ class InvokeAISettings(BaseSettings):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
|
||||
"""Add the argparse arguments for a setting parser."""
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
|
@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
def print_help(self, file=None) -> None:
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
@ -8,12 +8,12 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
from tqdm import tqdm, std
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
@ -49,12 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
"""
|
||||
self._jobs = {}
|
||||
self._jobs: Dict[int, DownloadJob] = {}
|
||||
self._next_job_id = 0
|
||||
self._queue = PriorityQueue()
|
||||
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
||||
self._stop_event = threading.Event()
|
||||
self._job_completed_event = threading.Event()
|
||||
self._worker_pool = set()
|
||||
self._worker_pool: Set[threading.Thread] = set()
|
||||
self._lock = threading.Lock()
|
||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||
self._event_bus = event_bus
|
||||
@ -424,7 +424,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
class TqdmProgress(object):
|
||||
"""TQDM-based progress bar object to use in on_progress handlers."""
|
||||
|
||||
_bars: Dict[int, tqdm] # the tqdm object
|
||||
_bars: Dict[int, tqdm] # type: ignore
|
||||
_last: Dict[int, int] # last bytes downloaded
|
||||
|
||||
def __init__(self) -> None: # noqa D107
|
||||
|
@ -5,7 +5,7 @@ import uuid
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
def get_timestamp() -> int:
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
SEED_MAX = np.iinfo(np.uint32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
def get_random_seed() -> int:
|
||||
rng = np.random.default_rng(seed=None)
|
||||
return int(rng.integers(0, SEED_MAX))
|
||||
|
||||
|
||||
def uuid_string():
|
||||
def uuid_string() -> str:
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
def is_optional(value: typing.Any):
|
||||
def is_optional(value: typing.Any) -> bool:
|
||||
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
|
||||
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
|
||||
|
@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
@ -70,7 +71,7 @@ class ModelLoaderBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its confguration.
|
||||
|
||||
@ -122,7 +123,7 @@ class AnyModelLoader:
|
||||
"""Return the convert cache associated used by the loaders."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its configuration.
|
||||
|
||||
@ -144,8 +145,8 @@ class AnyModelLoader:
|
||||
|
||||
@classmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
|
||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
@ -161,8 +162,8 @@ class AnyModelLoader:
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
|
||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
|
@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot,
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
|
||||
from .model_locker import ModelLocker, ModelLockerBase
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||
from .model_locker import ModelLocker
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
@ -20,7 +20,7 @@ from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata
|
||||
|
||||
|
||||
class ModelMetadataFetchBase(ABC):
|
||||
@ -62,5 +62,5 @@ class ModelMetadataFetchBase(ABC):
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> AnyModelRepoMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = AnyModelRepoMetadataValidator.validate_json(json)
|
||||
metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore
|
||||
return metadata
|
||||
|
@ -166,7 +166,7 @@ class ModelProbe(object):
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
|
||||
if format_type == ModelFormat.Diffusers:
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
|
@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
|
||||
# returns all models that have 'anime' in the path
|
||||
"""
|
||||
|
||||
models_found: Set[Path] = Field(default=None)
|
||||
scanned_dirs: Set[Path] = Field(default=None)
|
||||
pruned_paths: Set[Path] = Field(default=None)
|
||||
models_found: Optional[Set[Path]] = Field(default=None)
|
||||
scanned_dirs: Optional[Set[Path]] = Field(default=None)
|
||||
pruned_paths: Optional[Set[Path]] = Field(default=None)
|
||||
|
||||
def search_started(self) -> None:
|
||||
self.models_found = set()
|
||||
|
Loading…
Reference in New Issue
Block a user