more type mismatch fixes

This commit is contained in:
Lincoln Stein
2023-09-30 10:19:22 -04:00
parent 208d390779
commit 807ae821ea
5 changed files with 39 additions and 41 deletions

View File

@ -3,13 +3,14 @@
import pathlib
from enum import Enum
from typing import List, Literal, Optional, Union, Tuple
from typing import Any, List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_manager import (
OPENAPI_MODEL_CONFIGS,
@ -22,8 +23,6 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
from invokeai.app.api.dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
@ -33,15 +32,11 @@ models_router = APIRouter(prefix="/v1/models", tags=["models"])
# There are still numerous mypy errors here because it does not seem to like this
# way of dynamically generating the typing hints below.
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
InvokeAIModelConfig: Any = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: List[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
models: List[InvokeAIModelConfig]
class ModelImportStatus(BaseModel):
@ -93,12 +88,12 @@ async def list_models(
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=UpdateModelResponse,
response_model=InvokeAIModelConfig,
)
async def update_model(
key: str = Path(description="Unique key of model"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
info: InvokeAIModelConfig = Body(description="Model configuration"),
) -> InvokeAIModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
@ -106,7 +101,7 @@ async def update_model(
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
model_response = parse_obj_as(UpdateModelResponse, new_config.dict())
model_response = parse_obj_as(InvokeAIModelConfig, new_config.dict())
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -198,11 +193,11 @@ async def import_model(
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
response_model=InvokeAIModelConfig,
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
info: InvokeAIModelConfig = Body(description="Model configuration"),
) -> InvokeAIModelConfig:
"""
Add a model using the configuration information appropriate for its type. Only local models can be added by path.
This call will block until the model is installed.
@ -220,7 +215,7 @@ async def add_model(
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
return parse_obj_as(ImportModelResponse, new_config.dict())
return parse_obj_as(InvokeAIModelConfig, new_config.dict())
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@ -261,20 +256,20 @@ async def delete_model(
404: {"description": "Model not found"},
},
status_code=200,
response_model=ConvertModelResponse,
response_model=InvokeAIModelConfig,
)
async def convert_model(
key: str = Path(description="Unique key of model to remove from model registry."),
convert_dest_directory: Optional[str] = Query(
default=None, description="Save the converted model to the designated directory"
),
) -> ConvertModelResponse:
) -> InvokeAIModelConfig:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
try:
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest)
model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict()
response = parse_obj_as(ConvertModelResponse, model_raw)
response = parse_obj_as(InvokeAIModelConfig, model_raw)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
except ValueError as e:
@ -347,7 +342,7 @@ async def sync_to_config() -> bool:
409: {"description": "An identical merged model is already installed"},
},
status_code=200,
response_model=MergeModelResponse,
response_model=InvokeAIModelConfig,
)
async def merge_models(
keys: List[str] = Body(description="model name", min_items=2, max_items=3),
@ -361,7 +356,7 @@ async def merge_models(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> MergeModelResponse:
) -> InvokeAIModelConfig:
"""Merge the indicated diffusers model."""
logger = ApiDependencies.invoker.services.logger
try:
@ -375,7 +370,7 @@ async def merge_models(
force=force,
merge_dest_directory=dest,
)
response = parse_obj_as(ConvertModelResponse, result.dict())
response = parse_obj_as(InvokeAIModelConfig, result.dict())
except DuplicateModelException as e:
raise HTTPException(status_code=409, detail=str(e))
except UnknownModelException:

View File

@ -26,9 +26,9 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .events import EventServiceBase
from .config import InvokeAIAppConfig
from .events import EventServiceBase
if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext

View File

@ -208,7 +208,7 @@ class ModelCache(object):
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.cache_size = int(self.max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(

View File

@ -1,6 +1,6 @@
import inspect
from enum import Enum
from typing import Literal, get_origin
from typing import Any, Literal, get_origin
from pydantic import BaseModel
@ -89,8 +89,8 @@ MODEL_CLASSES = {
# },
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
MODEL_CONFIGS: Any = list()
OPENAPI_MODEL_CONFIGS: Any = list()
class OpenAPIModelInfoBase(BaseModel):

View File

@ -2,14 +2,11 @@ import base64
import importlib
import io
import math
import multiprocessing as mp
import os
import re
from collections import abc
from inspect import isfunction
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Optional
import numpy as np
import requests
@ -166,7 +163,7 @@ def ask_user(question: str, answers: list):
# -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Optional[Path]:
"""
Download a model file.
:param url: https, http or ftp URL
@ -183,10 +180,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
content_length = int(resp.headers.get("content-length", 0))
if dest.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except AttributeError:
file_name = os.path.basename(url)
file_name = response_attachment(resp) or os.path.basename(url)
dest = dest / file_name
else:
dest.parent.mkdir(parents=True, exist_ok=True)
@ -235,15 +229,24 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
return dest
def url_attachment_name(url: str) -> dict:
def response_attachment(response: requests.Response) -> Optional[str]:
try:
resp = requests.get(url, stream=True)
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
return match.group(1)
if disposition := response.headers.get("Content-Disposition"):
if match := re.search('filename="(.+)"', disposition):
return match.group(1)
return None
except Exception:
return None
def url_attachment_name(url: str) -> Optional[str]:
resp = requests.get(url)
if resp.ok:
return response_attachment(resp)
else:
return None
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None