convert implemented; need router

This commit is contained in:
Lincoln Stein 2023-07-05 09:05:05 -04:00
parent 5d099f4a49
commit 6112197edf
4 changed files with 106 additions and 5 deletions

View File

@ -22,6 +22,11 @@ class ImportModelResponse(BaseModel):
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConvertModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]

View File

@ -5,8 +5,7 @@ from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from typing import Optional, Union, Callable, List, Set, Dict, Tuple, types, TYPE_CHECKING
from invokeai.backend.model_management.model_manager import (
ModelManager,
@ -14,6 +13,8 @@ from invokeai.backend.model_management.model_manager import (
ModelType,
SubModelType,
ModelInfo,
AddModelResult,
SchedulerPredictionType,
)
from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
@ -111,7 +112,7 @@ class ModelManagerServiceBase(ABC):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False
) -> None:
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -135,6 +136,27 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
pass
@abstractmethod
def heuristic_import(self,
items_to_import: Set[str],
@ -336,6 +358,26 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type)
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type)
def commit(self, conf_file: Optional[Path]=None):
"""

View File

@ -234,7 +234,7 @@ import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree
from shutil import rmtree, move
import torch
from omegaconf import OmegaConf
@ -620,6 +620,60 @@ class ModelManager(object):
config = model_config,
)
def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
'''
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is a checkpoint.
'''
info = self.model_info(model_name, base_model, model_type)
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")
# We are taking advantage of a side effect of get_model() that converts check points
# into cached diffusers directories stored at `location`. It doesn't matter
# what submodeltype we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
model = self.get_model(model_name,
base_model,
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model / model_type / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_path}")
try:
move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')
result = self.add_model(model_name, base_model, model_type,
model_attributes = info,
clobber=True)
except:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
raise
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
checkpoint_path.unlink()
return result
def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")

View File

@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
)
)
else:
return model_path