revert(mm): restore convert route

psychedelicious 2024-03-05 18:40:17 +11:00
parent 7c9128b253
commit 48119d9010


@@ -2,6 +2,7 @@
 """FastAPI route for model configuration records."""
 
 import pathlib
+import shutil
 from typing import Any, Dict, List, Optional
 
 from fastapi import Body, Path, Query, Response
@@ -15,12 +16,14 @@ from invokeai.app.services.model_records import (
     InvalidModelException,
     UnknownModelException,
 )
-from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
+from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
+    MainCheckpointConfig,
     ModelFormat,
     ModelType,
+    SubModelType,
 )
 from invokeai.backend.model_manager.search import ModelSearch
@@ -465,8 +468,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
     installer.cancel_job(job)
 
 
-@model_manager_router.patch(
-    "/install/prune",
+@model_manager_router.delete(
+    "/install",
     operation_id="prune_model_install_jobs",
     responses={
         204: {"description": "All completed and errored jobs have been pruned"},
@@ -498,80 +501,80 @@ async def sync_models_to_config() -> Response:
     return Response(status_code=204)
 
 
-# @model_manager_router.put(
-#     "/convert/{key}",
-#     operation_id="convert_model",
-#     responses={
-#         200: {
-#             "description": "Model converted successfully",
-#             "content": {"application/json": {"example": example_model_config}},
-#         },
-#         400: {"description": "Bad request"},
-#         404: {"description": "Model not found"},
-#         409: {"description": "There is already a model registered at this location"},
-#     },
-# )
-# async def convert_model(
-#     key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
-# ) -> AnyModelConfig:
-#     """
-#     Permanently convert a model into diffusers format, replacing the safetensors version.
-#     Note that during the conversion process the key and model hash will change.
-#     The return value is the model configuration for the converted model.
-#     """
-#     model_manager = ApiDependencies.invoker.services.model_manager
-#     logger = ApiDependencies.invoker.services.logger
-#     loader = ApiDependencies.invoker.services.model_manager.load
-#     store = ApiDependencies.invoker.services.model_manager.store
-#     installer = ApiDependencies.invoker.services.model_manager.install
+@model_manager_router.put(
+    "/convert/{key}",
+    operation_id="convert_model",
+    responses={
+        200: {
+            "description": "Model converted successfully",
+            "content": {"application/json": {"example": example_model_config}},
+        },
+        400: {"description": "Bad request"},
+        404: {"description": "Model not found"},
+        409: {"description": "There is already a model registered at this location"},
+    },
+)
+async def convert_model(
+    key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
+) -> AnyModelConfig:
+    """
+    Permanently convert a model into diffusers format, replacing the safetensors version.
+    Note that during the conversion process the key and model hash will change.
+    The return value is the model configuration for the converted model.
+    """
+    model_manager = ApiDependencies.invoker.services.model_manager
+    logger = ApiDependencies.invoker.services.logger
+    loader = ApiDependencies.invoker.services.model_manager.load
+    store = ApiDependencies.invoker.services.model_manager.store
+    installer = ApiDependencies.invoker.services.model_manager.install

-#     try:
-#         model_config = store.get_model(key)
-#     except UnknownModelException as e:
-#         logger.error(str(e))
-#         raise HTTPException(status_code=424, detail=str(e))
+    try:
+        model_config = store.get_model(key)
+    except UnknownModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=424, detail=str(e))

-#     if not isinstance(model_config, MainCheckpointConfig):
-#         logger.error(f"The model with key {key} is not a main checkpoint model.")
-#         raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
+    if not isinstance(model_config, MainCheckpointConfig):
+        logger.error(f"The model with key {key} is not a main checkpoint model.")
+        raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

-#     # loading the model will convert it into a cached diffusers file
-#     model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
+    # loading the model will convert it into a cached diffusers file
+    model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)

-#     # Get the path of the converted model from the loader
-#     cache_path = loader.convert_cache.cache_path(key)
-#     assert cache_path.exists()
+    # Get the path of the converted model from the loader
+    cache_path = loader.convert_cache.cache_path(key)
+    assert cache_path.exists()

-#     # temporarily rename the original safetensors file so that there is no naming conflict
-#     original_name = model_config.name
-#     model_config.name = f"{original_name}.DELETE"
-#     changes = ModelRecordChanges(name=model_config.name)
-#     store.update_model(key, changes=changes)
+    # temporarily rename the original safetensors file so that there is no naming conflict
+    original_name = model_config.name
+    model_config.name = f"{original_name}.DELETE"
+    changes = ModelRecordChanges(name=model_config.name)
+    store.update_model(key, changes=changes)

-#     # install the diffusers
-#     try:
-#         new_key = installer.install_path(
-#             cache_path,
-#             config={
-#                 "name": original_name,
-#                 "description": model_config.description,
-#                 "hash": model_config.hash,
-#                 "source": model_config.source,
-#             },
-#         )
-#     except DuplicateModelException as e:
-#         logger.error(str(e))
-#         raise HTTPException(status_code=409, detail=str(e))
+    # install the diffusers
+    try:
+        new_key = installer.install_path(
+            cache_path,
+            config={
+                "name": original_name,
+                "description": model_config.description,
+                "hash": model_config.hash,
+                "source": model_config.source,
+            },
+        )
+    except DuplicateModelException as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))

-#     # delete the original safetensors file
-#     installer.delete(key)
+    # delete the original safetensors file
+    installer.delete(key)

-#     # delete the cached version
-#     shutil.rmtree(cache_path)
+    # delete the cached version
+    shutil.rmtree(cache_path)

-#     # return the config record for the new diffusers directory
-#     new_config: AnyModelConfig = store.get_model(new_key)
-#     return new_config
+    # return the config record for the new diffusers directory
+    new_config: AnyModelConfig = store.get_model(new_key)
+    return new_config
 
 
 # @model_manager_router.put(
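
For reference, the restored endpoint can be exercised from a client roughly as follows. This is a sketch, not part of the commit: the server address, port, and the /api/v2/models mount prefix for model_manager_router are assumptions, and "some-model-key" is a hypothetical key for an installed safetensors main model.

    # Usage sketch (assumed: local InvokeAI server on port 9090, router mounted at /api/v2/models)
    import requests

    base_url = "http://127.0.0.1:9090/api/v2/models"  # assumed mount point, not taken from this diff
    key = "some-model-key"  # hypothetical key of an installed safetensors main model

    # PUT /convert/{key} invokes the restored convert_model route
    response = requests.put(f"{base_url}/convert/{key}")

    if response.status_code == 200:
        # The returned record describes the new diffusers model; its key and hash differ from the original.
        new_config = response.json()
        print(new_config.get("key"), new_config.get("format"))
    else:
        # 400: not a main checkpoint; 424: unknown key; 409: a model is already registered at the target location.
        print(response.status_code, response.text)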