From 5d4d0e795ccb85fbf9a46ae7d592d716978b8d18 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 20:07:10 +1000 Subject: [PATCH 1/2] fix(mm): fix up mm service types --- .../app/services/model_manager_service.py | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 8b99c4a174..55df31d9c2 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -2,22 +2,26 @@ from __future__ import annotations -import torch from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING from dataclasses import dataclass +from pathlib import Path +from types import ModuleType +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union + +import torch -from invokeai.backend.model_management.model_manager import ( - ModelManager, - BaseModelType, - ModelType, - SubModelType, - ModelInfo, -) from invokeai.app.models.exceptions import CanceledException -from .config import InvokeAIAppConfig +from invokeai.backend.model_management.model_manager import (AddModelResult, + BaseModelType, + ModelInfo, + ModelManager, + ModelType, + SubModelType) +from invokeai.backend.model_management.models.base import \ + SchedulerPredictionType + from ...backend.util import choose_precision, choose_torch_device +from .config import InvokeAIAppConfig if TYPE_CHECKING: from ..invocations.baseinvocation import BaseInvocation, InvocationContext @@ -30,7 +34,7 @@ class ModelManagerServiceBase(ABC): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -137,9 +141,9 @@ class ModelManagerServiceBase(ABC): @abstractmethod def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -159,7 +163,7 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def commit(self, conf_file: Path = None) -> None: + def commit(self, conf_file: Optional[Path] = None) -> None: """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -173,7 +177,7 @@ class ModelManagerService(ModelManagerServiceBase): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -387,9 +391,9 @@ class ModelManagerService(ModelManagerServiceBase): return self.mgr.logger def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. From 56d4ea3252474b40cc6b55f02feba698f1d2da12 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 20:08:47 +1000 Subject: [PATCH 2/2] fix(api): improve mm routes --- invokeai/app/api/routers/models.py | 39 ++++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 872bce30fe..e0b864f8fd 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,13 +2,18 @@ from typing import Literal, Optional, Union -from fastapi import Query, Body, Path -from fastapi.routing import APIRouter, HTTPException +from fastapi import Body, Path, Query, Response +from fastapi.routing import APIRouter from pydantic import BaseModel, Field, parse_obj_as -from ..dependencies import ApiDependencies +from starlette.exceptions import HTTPException + from invokeai.backend import BaseModelType, ModelType from invokeai.backend.model_management import AddModelResult -from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType +from invokeai.backend.model_management.models import (MODEL_CONFIGS, + OPENAPI_MODEL_CONFIGS, + SchedulerPredictionType) + +from ..dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) @@ -75,6 +80,7 @@ async def update_model( responses= { 201: {"description" : "The model imported successfully"}, 404: {"description" : "The model could not be found"}, + 424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"}, 409: {"description" : "There is already a model corresponding to this path or repo_id"}, }, status_code=201, @@ -96,9 +102,14 @@ async def import_model( items_to_import = items_to_import, prediction_type_helper = lambda x: prediction_types.get(prediction_type) ) - if info := installed_models.get(name): - logger.info(f'Successfully imported {name}, got {info}') - return ImportModelResponse( + info = installed_models.get(name) + + if not info: + logger.error("Import failed") + raise HTTPException(status_code=424) + + logger.info(f'Successfully imported {name}, got {info}') + return ImportModelResponse( name = name, info = info, status = "success", @@ -124,10 +135,10 @@ async def import_model( }, ) async def delete_model( - base_model: BaseModelType = Path(default='sd-1', description="Base model"), - model_type: ModelType = Path(default='main', description="The type of model"), - model_name: str = Path(default=None, description="model name"), -) -> None: + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), +) -> Response: """Delete Model""" logger = ApiDependencies.invoker.services.logger @@ -137,14 +148,10 @@ async def delete_model( model_type = model_type ) logger.info(f"Deleted model: {model_name}") - raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") + return Response(status_code=204) except KeyError: logger.error(f"Model not found: {model_name}") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - else: - logger.info(f"Model deleted: {model_name}") - raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") - # @socketio.on("convertToDiffusers") # def convert_to_diffusers(model_to_convert: dict):