From 807ae821eaec7d14fb597f768e430689c63d124e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 30 Sep 2023 10:19:22 -0400 Subject: [PATCH] more type mismatch fixes --- invokeai/app/api/routers/models.py | 41 ++++++++----------- .../app/services/model_manager_service.py | 2 +- invokeai/backend/model_manager/cache.py | 2 +- .../backend/model_manager/models/__init__.py | 6 +-- invokeai/backend/util/util.py | 29 +++++++------ 5 files changed, 39 insertions(+), 41 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 2a3a06e928..b73feddb3e 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -3,13 +3,14 @@ import pathlib from enum import Enum -from typing import List, Literal, Optional, Union, Tuple +from typing import Any, List, Literal, Optional, Union from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter from pydantic import BaseModel, parse_obj_as from starlette.exceptions import HTTPException +from invokeai.app.api.dependencies import ApiDependencies from invokeai.backend import BaseModelType, ModelType from invokeai.backend.model_manager import ( OPENAPI_MODEL_CONFIGS, @@ -22,8 +23,6 @@ from invokeai.backend.model_manager import ( from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException from invokeai.backend.model_manager.merge import MergeInterpolationMethod -from invokeai.app.api.dependencies import ApiDependencies - models_router = APIRouter(prefix="/v1/models", tags=["models"]) # NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config @@ -33,15 +32,11 @@ models_router = APIRouter(prefix="/v1/models", tags=["models"]) # There are still numerous mypy errors here because it does not seem to like this # way of dynamically generating the typing hints below. -UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] +InvokeAIModelConfig: Any = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): - models: List[Union[tuple(OPENAPI_MODEL_CONFIGS)]] + models: List[InvokeAIModelConfig] class ModelImportStatus(BaseModel): @@ -93,12 +88,12 @@ async def list_models( 409: {"description": "There is already a model corresponding to the new name"}, }, status_code=200, - response_model=UpdateModelResponse, + response_model=InvokeAIModelConfig, ) async def update_model( key: str = Path(description="Unique key of model"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> UpdateModelResponse: + info: InvokeAIModelConfig = Body(description="Model configuration"), +) -> InvokeAIModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger @@ -106,7 +101,7 @@ async def update_model( info_dict = info.dict() info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict) - model_response = parse_obj_as(UpdateModelResponse, new_config.dict()) + model_response = parse_obj_as(InvokeAIModelConfig, new_config.dict()) except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: @@ -198,11 +193,11 @@ async def import_model( 409: {"description": "There is already a model corresponding to this path or repo_id"}, }, status_code=201, - response_model=ImportModelResponse, + response_model=InvokeAIModelConfig, ) async def add_model( - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> ImportModelResponse: + info: InvokeAIModelConfig = Body(description="Model configuration"), +) -> InvokeAIModelConfig: """ Add a model using the configuration information appropriate for its type. Only local models can be added by path. This call will block until the model is installed. @@ -220,7 +215,7 @@ async def add_model( info_dict = info.dict() info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict) - return parse_obj_as(ImportModelResponse, new_config.dict()) + return parse_obj_as(InvokeAIModelConfig, new_config.dict()) except UnknownModelException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) @@ -261,20 +256,20 @@ async def delete_model( 404: {"description": "Model not found"}, }, status_code=200, - response_model=ConvertModelResponse, + response_model=InvokeAIModelConfig, ) async def convert_model( key: str = Path(description="Unique key of model to remove from model registry."), convert_dest_directory: Optional[str] = Query( default=None, description="Save the converted model to the designated directory" ), -) -> ConvertModelResponse: +) -> InvokeAIModelConfig: """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" try: dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest) model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict() - response = parse_obj_as(ConvertModelResponse, model_raw) + response = parse_obj_as(InvokeAIModelConfig, model_raw) except UnknownModelException as e: raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}") except ValueError as e: @@ -347,7 +342,7 @@ async def sync_to_config() -> bool: 409: {"description": "An identical merged model is already installed"}, }, status_code=200, - response_model=MergeModelResponse, + response_model=InvokeAIModelConfig, ) async def merge_models( keys: List[str] = Body(description="model name", min_items=2, max_items=3), @@ -361,7 +356,7 @@ async def merge_models( description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None, ), -) -> MergeModelResponse: +) -> InvokeAIModelConfig: """Merge the indicated diffusers model.""" logger = ApiDependencies.invoker.services.logger try: @@ -375,7 +370,7 @@ async def merge_models( force=force, merge_dest_directory=dest, ) - response = parse_obj_as(ConvertModelResponse, result.dict()) + response = parse_obj_as(InvokeAIModelConfig, result.dict()) except DuplicateModelException as e: raise HTTPException(status_code=409, detail=str(e)) except UnknownModelException: diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 2422e7f3f0..64289d93ca 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -26,9 +26,9 @@ from invokeai.backend.model_manager import ( from invokeai.backend.model_manager.cache import CacheStats from invokeai.backend.model_manager.download import DownloadJobBase from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger -from .events import EventServiceBase from .config import InvokeAIAppConfig +from .events import EventServiceBase if TYPE_CHECKING: from ..invocations.baseinvocation import InvocationContext diff --git a/invokeai/backend/model_manager/cache.py b/invokeai/backend/model_manager/cache.py index 1d02e7d3bc..5d2ab1e182 100644 --- a/invokeai/backend/model_manager/cache.py +++ b/invokeai/backend/model_manager/cache.py @@ -208,7 +208,7 @@ class ModelCache(object): self.stats.hits += 1 if self.stats: - self.stats.cache_size = self.max_cache_size * GIG + self.stats.cache_size = int(self.max_cache_size * GIG) self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size()) self.stats.in_cache = len(self._cached_models) self.stats.loaded_model_sizes[key] = max( diff --git a/invokeai/backend/model_manager/models/__init__.py b/invokeai/backend/model_manager/models/__init__.py index 1b082ed53d..1f8e1dd474 100644 --- a/invokeai/backend/model_manager/models/__init__.py +++ b/invokeai/backend/model_manager/models/__init__.py @@ -1,6 +1,6 @@ import inspect from enum import Enum -from typing import Literal, get_origin +from typing import Any, Literal, get_origin from pydantic import BaseModel @@ -89,8 +89,8 @@ MODEL_CLASSES = { # }, } -MODEL_CONFIGS = list() -OPENAPI_MODEL_CONFIGS = list() +MODEL_CONFIGS: Any = list() +OPENAPI_MODEL_CONFIGS: Any = list() class OpenAPIModelInfoBase(BaseModel): diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 2e6031a2c8..78d7410fc5 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -2,14 +2,11 @@ import base64 import importlib import io import math -import multiprocessing as mp import os import re -from collections import abc from inspect import isfunction from pathlib import Path -from queue import Queue -from threading import Thread +from typing import Optional import numpy as np import requests @@ -166,7 +163,7 @@ def ask_user(question: str, answers: list): # ------------------------------------- -def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path: +def download_with_resume(url: str, dest: Path, access_token: str = None) -> Optional[Path]: """ Download a model file. :param url: https, http or ftp URL @@ -183,10 +180,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path content_length = int(resp.headers.get("content-length", 0)) if dest.is_dir(): - try: - file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1) - except AttributeError: - file_name = os.path.basename(url) + file_name = response_attachment(resp) or os.path.basename(url) dest = dest / file_name else: dest.parent.mkdir(parents=True, exist_ok=True) @@ -235,15 +229,24 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path return dest -def url_attachment_name(url: str) -> dict: +def response_attachment(response: requests.Response) -> Optional[str]: try: - resp = requests.get(url, stream=True) - match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")) - return match.group(1) + if disposition := response.headers.get("Content-Disposition"): + if match := re.search('filename="(.+)"', disposition): + return match.group(1) + return None except Exception: return None +def url_attachment_name(url: str) -> Optional[str]: + resp = requests.get(url) + if resp.ok: + return response_attachment(resp) + else: + return None + + def download_with_progress_bar(url: str, dest: Path) -> bool: result = download_with_resume(url, dest, access_token=None) return result is not None