all methods now return OPENAPI_MODEL_CONFIGS; convert uses PUT

This commit is contained in:
Lincoln Stein 2023-07-05 23:13:01 -04:00
parent 3691b55565
commit f7daa6e71d
4 changed files with 71 additions and 42 deletions

View File

@ -1,16 +1,15 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import ( from invokeai.backend.model_management.models import (
MODEL_CONFIGS,
OPENAPI_MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType SchedulerPredictionType
) )
@ -19,13 +18,9 @@ from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
class UpdateModelResponse(BaseModel): UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
model_name: str = Field(description="The name of the new model") ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info") ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ImportModelResponse(BaseModel):
location: str = Field(description="The path, repo_id or URL of the imported model")
info: AddModelResult = Field(description="The model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@ -62,7 +57,7 @@ async def update_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"), base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"), model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"), model_name: str = Path(default=None, description="model name"),
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"), info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse: ) -> UpdateModelResponse:
""" Add Model """ """ Add Model """
try: try:
@ -72,14 +67,12 @@ async def update_model(
model_type=model_type, model_type=model_type,
model_attributes=info.dict() model_attributes=info.dict()
) )
model_response = UpdateModelResponse( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name = model_name, model_name=model_name,
info = ApiDependencies.invoker.services.model_manager.model_info( base_model=base_model,
model_name=model_name, model_type=model_type,
base_model=base_model,
model_type=model_type,
)
) )
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e: except KeyError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
@ -122,10 +115,13 @@ async def import_model(
raise HTTPException(status_code=424) raise HTTPException(status_code=424)
logger.info(f'Successfully imported {location}, got {info}') logger.info(f'Successfully imported {location}, got {info}')
return ImportModelResponse( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
location = location, model_name=info.name,
info = info, base_model=info.base_model,
model_type=info.model_type
) )
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e: except KeyError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -165,7 +161,7 @@ async def delete_model(
logger.error(f"Model not found: {model_name}") logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
@models_router.patch( @models_router.put(
"/convert/{base_model}/{model_type}/{model_name}", "/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model", operation_id="convert_model",
responses={ responses={
@ -174,26 +170,30 @@ async def delete_model(
404: { "description": "Model not found" }, 404: { "description": "Model not found" },
}, },
status_code = 200, status_code = 200,
response_model = Union[tuple(MODEL_CONFIGS)], response_model = Union[tuple(OPENAPI_MODEL_CONFIGS)],
) )
async def convert_model( async def convert_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"), model_name: str = Path(description="model name"),
) -> Union[tuple(MODEL_CONFIGS)]: ) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Converting model: {model_name}") logger.info(f"Converting model: {model_name}")
result = ApiDependencies.invoker.services.model_manager.convert_model(model_name, ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model, base_model = base_model,
model_type = model_type model_type = model_type)
) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError: except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return result.config return response
# @socketio.on("mergeDiffusersModels") # @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict): # def merge_diffusers_models(model_merge_info: dict):

View File

@ -76,13 +76,7 @@ class ModelManagerServiceBase(ABC):
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" Uses the exact format as the omegaconf stanza.
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
""" """
pass pass
@ -104,7 +98,20 @@ class ModelManagerServiceBase(ABC):
} }
""" """
pass pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod @abstractmethod
def add_model( def add_model(
@ -339,12 +346,19 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None model_type: Optional[ModelType] = None
) -> list[dict]: ) -> list[dict]:
# ) -> dict:
""" """
Return a list of models. Return a list of models.
""" """
return self.mgr.list_models(base_model, model_type) return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name,
base_model=base_model,
model_type=model_type)
def add_model( def add_model(
self, self,
model_name: str, model_name: str,

View File

@ -480,7 +480,7 @@ class ModelManager(object):
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models: if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True, exclude={"error"}) return self.models[model_key].dict(exclude_defaults=True)
else: else:
return None # TODO: None or empty dict on not found return None # TODO: None or empty dict on not found
@ -491,17 +491,32 @@ class ModelManager(object):
""" """
return [(self.parse_key(x)) for x in self.models.keys()] return [(self.parse_key(x)) for x in self.models.keys()]
def list_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
)->dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model,model_type,model_name)
return models[0] if models else None
def list_models( def list_models(
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
model_name: Optional[str] = None,
) -> list[dict]: ) -> list[dict]:
""" """
Return a list of models. Return a list of models.
""" """
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = [] models = []
for model_key in sorted(self.models, key=str.casefold): for model_key in model_keys:
model_config = self.models[model_key] model_config = self.models[model_key]
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@ -653,7 +668,7 @@ class ModelManager(object):
old_diffusers_path = self.app_config.models_path / model.location old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
if new_diffusers_path.exists(): if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_path}") raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try: try:
move(old_diffusers_path,new_diffusers_path) move(old_diffusers_path,new_diffusers_path)

View File

@ -26,7 +26,7 @@ class MergeInterpolationMethod(str, Enum):
def merge_diffusion_models( def merge_diffusion_models(
model_paths: List[Path], model_paths: List[Path],
alpha: float = 0.5, alpha: float = 0.5,
interp: InterpolationMethod = None, interp: MergeInterpolationMethod = None,
force: bool = False, force: bool = False,
**kwargs, **kwargs,
) -> DiffusionPipeline: ) -> DiffusionPipeline:
@ -67,7 +67,7 @@ def merge_diffusion_models_and_save (
merged_model_name: str, merged_model_name: str,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
alpha: float = 0.5, alpha: float = 0.5,
interp: InterpolationMethod = None, interp: MergeInterpolationMethod = None,
force: bool = False, force: bool = False,
**kwargs, **kwargs,
): ):