diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 343af89023..a4360a3285 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,25 +2,28 @@ from typing import Literal, Optional, Union -from fastapi import Query, Body, Path -from fastapi.routing import APIRouter, HTTPException +from fastapi import Body, Path, Query, Response +from fastapi.routing import APIRouter from pydantic import BaseModel, Field, parse_obj_as -from ..dependencies import ApiDependencies +from starlette.exceptions import HTTPException + from invokeai.backend import BaseModelType, ModelType from invokeai.backend.model_management import AddModelResult -from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType +from invokeai.backend.model_management.models import (MODEL_CONFIGS, + OPENAPI_MODEL_CONFIGS, + SchedulerPredictionType) + +from ..dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) -class CreateModelResponse(BaseModel): +class UpdateModelResponse(BaseModel): model_name: str = Field(description="The name of the new model") info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info") - status: str = Field(description="The status of the API response") class ImportModelResponse(BaseModel): - name: str = Field(description="The name of the imported model") + location: str = Field(description="The path, repo_id or URL of the imported model") info: AddModelResult = Field(description="The model info") - status: str = Field(description="The status of the API response") class ConvertModelResponse(BaseModel): name: str = Field(description="The name of the imported model") @@ -48,51 +51,65 @@ async def list_models( models = parse_obj_as(ModelsList, { "models": models_raw }) return models -@models_router.post( +@models_router.patch( "/{base_model}/{model_type}/{model_name}", operation_id="update_model", - responses={200: {"status": "success"}}, + responses={200: {"description" : "The model was updated successfully"}, + 404: {"description" : "The model could not be found"}, + 400: {"description" : "Bad request"} + }, + status_code = 200, + response_model = UpdateModelResponse, ) async def update_model( base_model: BaseModelType = Path(default='sd-1', description="Base model"), model_type: ModelType = Path(default='main', description="The type of model"), model_name: str = Path(default=None, description="model name"), info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> CreateModelResponse: +) -> UpdateModelResponse: """ Add Model """ - ApiDependencies.invoker.services.model_manager.add_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - model_attributes=info.dict(), - clobber=True, - ) - model_response = CreateModelResponse( - model_name = model_name, - info = info, - status="success") + try: + ApiDependencies.invoker.services.model_manager.update_model( + model_name=model_name, + base_model=base_model, + model_type=model_type, + model_attributes=info.dict() + ) + model_response = UpdateModelResponse( + model_name = model_name, + info = ApiDependencies.invoker.services.model_manager.model_info( + model_name=model_name, + base_model=base_model, + model_type=model_type, + ) + ) + except KeyError as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) return model_response @models_router.post( - "/import", + "/", operation_id="import_model", responses= { 201: {"description" : "The model imported successfully"}, 404: {"description" : "The model could not be found"}, + 424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"}, 409: {"description" : "There is already a model corresponding to this path or repo_id"}, }, status_code=201, response_model=ImportModelResponse ) async def import_model( - name: str = Body(description="A model path, repo_id or URL to import"), + location: str = Body(description="A model path, repo_id or URL to import"), prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), ) -> ImportModelResponse: """ Add a model using its local path, repo_id, or remote URL """ - items_to_import = {name} + items_to_import = {location} prediction_types = { x.value: x for x in SchedulerPredictionType } logger = ApiDependencies.invoker.services.logger @@ -101,12 +118,16 @@ async def import_model( items_to_import = items_to_import, prediction_type_helper = lambda x: prediction_types.get(prediction_type) ) - if info := installed_models.get(name): - logger.info(f'Successfully imported {name}, got {info}') - return ImportModelResponse( - name = name, + info = installed_models.get(location) + + if not info: + logger.error("Import failed") + raise HTTPException(status_code=424) + + logger.info(f'Successfully imported {location}, got {info}') + return ImportModelResponse( + location = location, info = info, - status = "success", ) except KeyError as e: logger.error(str(e)) @@ -129,10 +150,10 @@ async def import_model( }, ) async def delete_model( - base_model: BaseModelType = Path(default='sd-1', description="Base model"), - model_type: ModelType = Path(default='main', description="The type of model"), - model_name: str = Path(default=None, description="model name"), -) -> None: + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), +) -> Response: """Delete Model""" logger = ApiDependencies.invoker.services.logger @@ -142,14 +163,10 @@ async def delete_model( model_type = model_type ) logger.info(f"Deleted model: {model_name}") - raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") + return Response(status_code=204) except KeyError: logger.error(f"Model not found: {model_name}") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - else: - logger.info(f"Model deleted: {model_name}") - raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") - # @socketio.on("convertToDiffusers") # def convert_to_diffusers(model_to_convert: dict): diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 65c0510cb5..0834b11559 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -2,10 +2,10 @@ from __future__ import annotations -import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Union, Callable, List, Set, Dict, Tuple, types, TYPE_CHECKING +from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING +from types import ModuleType from invokeai.backend.model_management.model_manager import ( ModelManager, @@ -16,9 +16,11 @@ from invokeai.backend.model_management.model_manager import ( AddModelResult, SchedulerPredictionType, ) + +import torch from invokeai.app.models.exceptions import CanceledException -from .config import InvokeAIAppConfig from ...backend.util import choose_precision, choose_torch_device +from .config import InvokeAIAppConfig if TYPE_CHECKING: from ..invocations.baseinvocation import BaseInvocation, InvocationContext @@ -31,7 +33,7 @@ class ModelManagerServiceBase(ABC): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -122,6 +124,24 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def update_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + model_attributes: dict, + ) -> AddModelResult: + """ + Update the named model with a dictionary of attributes. Will fail with a + KeyErrorException if the name does not already exist. + + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model name is missing. Call commit() to write changes to disk. + """ + pass + @abstractmethod def del_model( self, @@ -159,9 +179,9 @@ class ModelManagerServiceBase(ABC): @abstractmethod def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -181,7 +201,7 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def commit(self, conf_file: Path = None) -> None: + def commit(self, conf_file: Optional[Path] = None) -> None: """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -195,7 +215,7 @@ class ModelManagerService(ModelManagerServiceBase): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -343,7 +363,25 @@ class ModelManagerService(ModelManagerServiceBase): self.logger.debug(f'add/update model {model_name}') return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) - + def update_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + model_attributes: dict, + ) -> AddModelResult: + """ + Update the named model with a dictionary of attributes. Will fail with a + KeyError exception if the name does not already exist. + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model name is missing. Call commit() to write changes to disk. + """ + self.logger.debug(f'update model {model_name}') + if not self.model_exists(model_name, base_model, model_type): + raise KeyError(f"Unknown model {model_name}") + return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) + def del_model( self, model_name: str, @@ -429,9 +467,9 @@ class ModelManagerService(ModelManagerServiceBase): return self.mgr.logger def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 0aab394318..28183a12e9 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -480,7 +480,7 @@ class ModelManager(object): """ model_key = self.create_key(model_name, base_model, model_type) if model_key in self.models: - return self.models[model_key].dict(exclude_defaults=True) + return self.models[model_key].dict(exclude_defaults=True, exclude={"error"}) else: return None # TODO: None or empty dict on not found