all methods in router API now tested and working

Lincoln Stein
2023-09-16 19:43:01 -04:00
parent dc683475d4
commit c029534243
12 changed files with 165 additions and 95 deletions

View File

@@ -2,7 +2,6 @@
import pathlib
import traceback
from typing import List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
@@ -13,12 +12,13 @@ from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_manager import (
OPENAPI_MODEL_CONFIGS,
DuplicateModelException,
InvalidModelException,
MergeInterpolationMethod,
ModelConfigBase,
SchedulerPredictionType,
UnknownModelException,
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
from ..dependencies import ApiDependencies
@@ -225,15 +225,13 @@ async def convert_model(
),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
info = ApiDependencies.invoker.services.model_manager.model_info(key)
try:
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest)
model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict()
response = parse_obj_as(ConvertModelResponse, model_raw)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@@ -252,6 +250,7 @@ async def convert_model(
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models"),
) -> List[pathlib.Path]:
"""Search for all models in a server-local path."""
if not search_path.is_dir():
raise HTTPException(
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
@@ -283,27 +282,31 @@ async def list_ckpt_configs() -> List[pathlib.Path]:
response_model=bool,
)
async def sync_to_config() -> bool:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
"""
Synchronize model in-memory data structures with disk.
Call after making changes to models.yaml, autoimport directories
or models directory.
"""
ApiDependencies.invoker.services.model_manager.sync_to_config()
return True
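A quick client-side sketch of when to hit this route (the router prefix, port, and HTTP verb are not visible in this hunk and are assumptions):

import requests

# Hypothetical: resynchronize after hand-editing models.yaml.
# The "/api/v1/models/sync" path and POST verb are assumed, not shown above.
requests.post("http://localhost:9090/api/v1/models/sync").raise_for_status()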
@models_router.put(
"/merge/{base_model}",
"/merge",
operation_id="merge_models",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Incompatible models"},
404: {"description": "One or more models not found"},
409: {"description": "An identical merged model is already installed"},
},
status_code=200,
response_model=MergeModelResponse,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"),
keys: List[str] = Body(description="Keys of the models to merge", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(
@@ -314,28 +317,24 @@ async def merge_models(
default=None,
),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
"""Merge the indicated diffusers model."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
result: ModelConfigBase = ApiDependencies.invoker.services.model_manager.merge_models(
model_keys=keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model,
model_type=ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
response = parse_obj_as(ConvertModelResponse, result.dict())
except DuplicateModelException as e:
raise HTTPException(status_code=409, detail=str(e))
except UnknownModelException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
raise HTTPException(status_code=404, detail=f"One or more of the models '{keys}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
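For reference, a sketch of exercising the reworked key-based merge route; the "/merge" path and body fields come from this diff, while the host, port, and router prefix are assumptions:

import requests

response = requests.put(
    "http://localhost:9090/api/v1/models/merge",  # router prefix assumed
    json={
        "keys": ["key-of-model-a", "key-of-model-b"],  # two or three model keys
        "merged_model_name": "my-merged-model",  # optional; names are joined with "+" if omitted
        "alpha": 0.5,
        "interp": "add_difference",
    },
)
response.raise_for_status()
merged = response.json()  # a MergeModelResponse payload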

View File

@@ -14,18 +14,17 @@ from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_manager import (
BaseModelType,
DuplicateModelException,
MergeInterpolationMethod,
ModelConfigBase,
ModelInfo,
ModelInstallJob,
ModelLoader,
ModelMerger,
ModelLoad,
ModelSearch,
ModelType,
SubModelType,
UnknownModelException,
)
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .config import InvokeAIAppConfig
@@ -291,7 +290,7 @@ class ModelManagerServiceBase(ABC):
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory."""
_loader: ModelLoader = Field(description="InvokeAIAppConfig object for the current process")
_loader: ModelLoad = Field(description="Model loader for the current process")
_event_bus: "EventServiceBase" = Field(description="an event bus to send install events to", default=None)
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None):
@@ -304,7 +303,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
self._event_bus = event_bus
handlers = [self._event_bus.emit_model_event] if self._event_bus else None
self._loader = ModelLoader(config, event_handlers=handlers)
self._loader = ModelLoad(config, event_handlers=handlers)
def get_model(
self,
@@ -500,7 +499,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
@@ -514,8 +513,12 @@
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self._loader)
merger = ModelMerger(self._loader.store)
try:
if not merged_model_name:
merged_model_name = "+".join([self._loader.store.get_model(x).name for x in model_keys])
self.logger.error("ModelMerger needs to be rewritten.")
raise Exception("not implemented")
result = merger.merge_diffusion_models_and_save(
model_keys=model_keys,

View File

@@ -7,7 +7,7 @@ from .model_manager import ( # noqa F401
InvalidModelException,
ModelConfigStore,
ModelInstall,
ModelLoader,
ModelLoad,
ModelType,
ModelVariantType,
SchedulerPredictionType,

View File

@@ -14,11 +14,10 @@ from .config import ( # noqa F401
SubModelType,
)
from .install import ModelInstall, ModelInstallJob # noqa F401
from .loader import ModelInfo, ModelLoader # noqa F401
from .loader import ModelInfo, ModelLoad # noqa F401
from .lora import ModelPatcher, ONNXModelPatcher
from .merge import MergeInterpolationMethod, ModelMerger
from .models import OPENAPI_MODEL_CONFIGS, read_checkpoint_meta # noqa F401
from .probe import InvalidModelException, ModelProbe # noqa F401
from .probe import InvalidModelException, ModelProbeInfo # noqa F401
from .search import ModelSearch # noqa F401
from .storage import ( # noqa F401
DuplicateModelException,

View File

@@ -128,6 +128,12 @@ class DownloadQueueBase(ABC):
:param variant: Variant to download, such as "fp16" (repo_ids only).
:param event_handlers: Optional callables that will be called whenever job status changes.
:returns the job: job.id will be a non-negative value after execution
Known variants currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
pass
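A hedged sketch of submitting a variant-specific download; the queue object and the create_download_job name below are assumptions rather than signatures confirmed by this diff:

from pathlib import Path

# Hypothetical call shape; only the variant values are documented above.
job = queue.create_download_job(
    source="stabilityai/stable-diffusion-2-1",  # a HuggingFace repo_id
    destdir=Path("/tmp/models"),
    variant="fp16",  # or "onnx", "openvino", or None for the fp32 model
)
assert job.id >= 0  # per the docstring, job.id is non-negative after submission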

View File

@@ -118,7 +118,9 @@ class DownloadQueue(DownloadQueueBase):
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> DownloadJobBase:
"""Create a download job and return its ID."""
"""
Create a download job and return its ID.
"""
kwargs = dict()
if Path(source).exists():
@@ -503,8 +505,8 @@ class DownloadQueue(DownloadQueueBase):
repo_id = job.source
variant = job.variant
urls_to_download, metadata = self._get_repo_info(repo_id, variant)
if job.destination.stem != Path(repo_id).stem:
job.destination = job.destination / Path(repo_id).stem
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
job.metadata = metadata
bytes_downloaded = dict()
job.total_bytes = 0
@@ -535,7 +537,15 @@
repo_id: str,
variant: Optional[str] = None,
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
Known variants currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
paths = [x.rfilename for x in sibs]
@@ -564,7 +574,19 @@
basenames = dict()
for p in paths:
path = Path(p)
if path.suffix in [".bin", ".safetensors", ".pt"]:
if path.suffix == ".onnx":
if variant == "onnx":
result.add(path)
elif path.name.startswith("openvino_model"):
if variant == "openvino":
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
@@ -584,10 +606,13 @@
basenames[basename] = path
else:
basenames[basename] = path
else:
result.add(path)
continue
for v in basenames.values():
result.add(v)
return result
def _download_path(self, job: DownloadJobBase):
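To make the new branching concrete, here is how a typical repo listing would be filtered (filenames invented for illustration):

# With variant="fp16":
#   unet/diffusion_pytorch_model.fp16.safetensors -> selected (fp16 basename preferred)
#   unet/diffusion_pytorch_model.safetensors      -> skipped in favor of the fp16 file
#   model_index.json, tokenizer/merges.txt        -> always selected (.json/.txt branch)
#   model.onnx                                    -> selected only when variant="onnx"
#   openvino_model.bin                            -> selected only when variant="openvino"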

View File

@@ -146,13 +146,19 @@ class ModelInstallBase(ABC):
"""Return the download queue used by the installer."""
pass
@property
@abstractmethod
def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
def store(self) -> ModelConfigStore:
"""Return the storage backend used by the installer."""
pass
@abstractmethod
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
"""
Probe and register the model at model_path.
:param model_path: Filesystem Path to the model.
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
:param overrides: Dict of attributes that will override probed values.
:returns id: The string ID of the registered model.
"""
pass
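A usage sketch of the revised signature; the path and override value are invented, and the constructor arguments mirror the ModelInstall(store=..., config=...) call that appears later in this commit:

installer = ModelInstall(store=store, config=config)
key = installer.register_path(
    "/path/to/model.safetensors",
    overrides={"name": "my-preferred-name"},  # replaces the probed name
)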
@@ -201,6 +207,12 @@ class ModelInstallBase(ABC):
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
pass
@@ -349,6 +361,11 @@ class ModelInstall(ModelInstallBase):
"""Return the queue."""
return self._download_queue
@property
def store(self) -> ModelConfigStore:
"""Return the storage backend used by the installer."""
return self._store
def register_path(
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
) -> str: # noqa D102
@@ -360,7 +377,7 @@ class ModelInstall(ModelInstallBase):
key: str = FastModelHash.hash(model_path)
registration_data = dict(
path=model_path.as_posix(),
name=model_path.stem,
name=model_path.name if model_path.is_dir() else model_path.stem,
base_model=info.base_type,
model_type=info.model_type,
model_format=info.format,
@@ -581,7 +598,7 @@ class ModelInstall(ModelInstallBase):
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
from .loader import ModelInfo, ModelLoader # to avoid circular imports
from .loader import ModelInfo, ModelLoad # to avoid circular imports
new_diffusers_path = None
@@ -594,7 +611,7 @@ class ModelInstall(ModelInstallBase):
# We are taking advantage of a side effect of get_model() that converts checkpoints
# into cached diffusers directories stored at `path`. It doesn't matter
# what submodel type we request here, so we get the smallest.
loader = ModelLoader(self._config)
loader = ModelLoad(self._config)
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
converted_model: ModelInfo = loader.get_model(key, **submodel)

View File

@@ -42,7 +42,7 @@ class ModelInfo:
self.context.__exit__(*args, **kwargs)
class ModelLoaderBase(ABC):
class ModelLoadBase(ABC):
"""Abstract base class for a model loader which works with the ModelConfigStore backend."""
@abstractmethod
@@ -113,8 +113,8 @@ class ModelLoaderBase(ABC):
pass
class ModelLoader(ModelLoaderBase):
"""Implementation of ModelLoaderBase."""
class ModelLoad(ModelLoadBase):
"""Implementation of ModelLoadBase."""
_app_config: InvokeAIAppConfig
_store: ModelConfigStore
@@ -130,7 +130,7 @@ class ModelLoader(ModelLoaderBase):
event_handlers: Optional[List[DownloadEventHandler]] = None,
):
"""
Initialize ModelLoader object.
Initialize ModelLoad object.
:param config: The app's InvokeAIAppConfig object.
"""

View File

@@ -9,14 +9,16 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
import warnings
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from . import BaseModelType, ModelConfigBase, ModelLoader, ModelType, ModelVariantType
from . import ModelConfigBase, ModelConfigStore, ModelInstall, ModelType
from .probe import ModelProbe, ModelProbeInfo
class MergeInterpolationMethod(str, Enum):
@@ -27,8 +29,18 @@ class MergeInterpolationMethod(str, Enum):
class ModelMerger(object):
def __init__(self, manager: ModelLoader):
self.manager = manager
_store: ModelConfigStore
_config: InvokeAIAppConfig
def __init__(self, store: ModelConfigStore, config: Optional[InvokeAIAppConfig] = None):
"""
Initialize a ModelMerger object.
:param store: Underlying storage manager for the running process.
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
"""
self._store = store
self._config = config or InvokeAIAppConfig.get_config()
def merge_diffusion_models(
self,
@@ -70,8 +82,7 @@ class ModelMerger(object):
def merge_diffusion_models_and_save(
self,
model_names: List[str],
base_model: Union[BaseModelType, str],
model_keys: List[str],
merged_model_name: str,
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
@@ -93,24 +104,36 @@ class ModelMerger(object):
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
model_names = list()
config = self._config
store = self._store
base_models = set()
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert (
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
for key in model_keys:
info = store.get_model(key)
model_names.append(info.name)
assert (
info["model_format"] == "diffusers"
), f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
info.model_format == "diffusers"
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
assert (
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
info.variant == "normal"
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([(config.root_path / info["path"]).as_posix()])
if key == model_keys[0]:
vae = info.vae
# tally base models used
base_models.add(info.base_model)
model_paths.extend([(config.models_path / info.path).as_posix()])
assert len(base_models) == 1, f"All models to merge must have the same base model, but found bases {base_models}"
base_model = base_models.pop()
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
logger.debug(f"interp = {interp}, merge_method={merge_method}")
@@ -126,18 +149,11 @@ class ModelMerger(object):
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
# register model and get its unique key
info = ModelProbeInfo(
model_type=ModelType.Main,
base_type=base_model,
format="diffusers",
)
key = self.manager.installer.register_path(
model_path=dump_path,
info=info,
)
installer = ModelInstall(store=self._store, config=self._config)
key = installer.register_path(dump_path)
# update model's config
model_config = self.manager.store.get_model(key)
model_config = self._store.get_model(key)
model_config.update(
dict(
name=merged_model_name,
@@ -145,5 +161,5 @@ class ModelMerger(object):
vae=vae,
)
)
self.manager.store.update_model(key, model_config)
self._store.update_model(key, model_config)
return model_config
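Putting the store-based constructor and the key-based merge together, a hedged usage sketch (keys and names invented; the enum member comes from the assertion above):

merger = ModelMerger(store)  # store: a ModelConfigStore
merged_config = merger.merge_diffusion_models_and_save(
    model_keys=["key-a", "key-b", "key-c"],  # three keys require add_difference
    merged_model_name="my-three-way-merge",
    alpha=0.5,
    interp=MergeInterpolationMethod.AddDifference,
)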

View File

@@ -15,26 +15,28 @@ from typing import Callable, Optional
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from pydantic import BaseModel
from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType
from .util import SilenceWarnings, lora_token_vector_length, read_checkpoint_meta
from .hash import FastModelHash
from .util import lora_token_vector_length, read_checkpoint_meta
class InvalidModelException(Exception):
"""Raised when an invalid model is encountered."""
@dataclass
class ModelProbeInfo(object):
class ModelProbeInfo(BaseModel):
"""Fields describing a probed model."""
model_type: ModelType
base_type: BaseModelType
format: ModelFormat
variant_type: ModelVariantType = "normal"
prediction_type: SchedulerPredictionType = "v_prediction"
upcast_attention: bool = False
image_size: int = None
hash: str
variant_type: Optional[ModelVariantType] = "normal"
prediction_type: Optional[SchedulerPredictionType] = "v_prediction"
upcast_attention: Optional[bool] = False
image_size: Optional[int] = None
class ModelProbeBase(ABC):
@@ -131,6 +133,7 @@ class ModelProbe(ModelProbeBase):
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
format = probe.get_format()
hash = FastModelHash.hash(model)
model_info = ModelProbeInfo(
model_type=model_type,
@@ -142,6 +145,7 @@ class ModelProbe(ModelProbeBase):
and prediction_type == SchedulerPredictionType.VPrediction
),
format=format,
hash=hash,
image_size=1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else 768
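Because ModelProbeInfo is now a pydantic BaseModel, probe results validate on construction and serialize cleanly; a construction sketch (field values invented, enum member names assumed):

info = ModelProbeInfo(
    model_type=ModelType.Main,
    base_type=BaseModelType.StableDiffusion1,   # enum member name assumed
    format=ModelFormat.Diffusers,               # enum member name assumed
    hash=FastModelHash.hash("/path/to/model"),  # hash is now a required field
)
print(info.dict())  # same serialization used by the probe CLI below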

View File

@@ -3,6 +3,7 @@
"""Little command-line utility for probing a model on disk."""
import argparse
import json
import sys
from pathlib import Path
@@ -25,6 +26,6 @@ args = parser.parse_args()
for path in args.model_path:
try:
info = ModelProbe().probe(path, helper)
print(f"{path}: {info}")
print(f"{path}:{json.dumps(info.dict(), sort_keys=True, indent=4)}")
except InvalidModelException as exc:
print(exc)

View File

@@ -4,7 +4,7 @@ import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import BaseModelType, ModelConfigStore, ModelType, SubModelType
from invokeai.backend.model_manager import ModelLoader
from invokeai.backend.model_manager import ModelLoad
BASIC_MODEL_NAME = "sdxl-base-1-0"
VAE_OVERRIDE_MODEL_NAME = "sdxl-base-with-custom-vae-1-0"
@@ -12,18 +12,18 @@ VAE_NULL_OVERRIDE_MODEL_NAME = "sdxl-base-with-empty-vae-1-0"
@pytest.fixture
def model_manager(datadir) -> ModelLoader:
def model_manager(datadir) -> ModelLoad:
config = InvokeAIAppConfig(root=datadir, conf_path="configs/relative_sub.models.yaml")
return ModelLoader(config=config)
return ModelLoad(config=config)
def test_get_model_names(model_manager: ModelLoader):
def test_get_model_names(model_manager: ModelLoad):
store = model_manager.store
names = [x.name for x in store.all_models()]
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path):
def test_get_model_path_for_diffusers(model_manager: ModelLoad, datadir: Path):
models = model_manager.store.search_by_name(model_name=BASIC_MODEL_NAME)
assert len(models) == 1
model_config = models[0]
@@ -33,7 +33,7 @@ def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path)
assert not is_override
def test_get_model_path_for_overridden_vae(model_manager: ModelLoader, datadir: Path):
def test_get_model_path_for_overridden_vae(model_manager: ModelLoad, datadir: Path):
models = model_manager.store.search_by_name(model_name=VAE_OVERRIDE_MODEL_NAME)
assert len(models) == 1
model_config = models[0]
@@ -43,7 +43,7 @@ def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoader, datadir:
assert is_override
def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoader, datadir: Path):
def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoad, datadir: Path):
model_config = model_manager.store.search_by_name(model_name=VAE_NULL_OVERRIDE_MODEL_NAME)[0]
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
assert not is_override