mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add the import model router
This commit is contained in:
committed by
psychedelicious
parent
0988725c1b
commit
96bf92ead4
@ -2,17 +2,17 @@
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi import Query, Body
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import AddModelResult
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel):
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ImportModelRequest(BaseModel):
|
||||
name: str = Field(description="A model path, repo_id or URL to import")
|
||||
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
|
||||
class ImportModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the imported model")
|
||||
# base_model: str = Field(description="The base model")
|
||||
# model_type: str = Field(description="The model type")
|
||||
info: AddModelResult = Field(description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
@ -86,7 +89,6 @@ async def list_models(
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
operation_id="update_model",
|
||||
@ -109,27 +111,38 @@ async def update_model(
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses={200: {"status": "success"}},
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def import_model(
|
||||
model_request: ImportModelRequest
|
||||
) -> None:
|
||||
""" Add Model """
|
||||
items_to_import = set([model_request.name])
|
||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
items_to_import = {name}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
if len(installed_models) > 0:
|
||||
logger.info(f'Successfully imported {model_request.name}')
|
||||
if info := installed_models.get(name):
|
||||
logger.info(f'Successfully imported {name}, got {info}')
|
||||
return ImportModelResponse(
|
||||
name = name,
|
||||
info = info,
|
||||
status = "success",
|
||||
)
|
||||
else:
|
||||
logger.error(f'Model {model_request.name} not imported')
|
||||
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
|
||||
logger.error(f'Model {name} not imported')
|
||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
|
Reference in New Issue
Block a user