mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
all methods in router API now tested and working
This commit is contained in:
@ -2,7 +2,6 @@
|
||||
|
||||
|
||||
import pathlib
|
||||
import traceback
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@ -13,12 +12,13 @@ from starlette.exceptions import HTTPException
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
MergeInterpolationMethod,
|
||||
ModelConfigBase,
|
||||
SchedulerPredictionType,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -225,15 +225,13 @@ async def convert_model(
|
||||
),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
info = ApiDependencies.invoker.services.model_manager.model_info(key)
|
||||
try:
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(key, convert_dest_directory=dest)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.model_info(key).dict()
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@ -252,6 +250,7 @@ async def convert_model(
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||
) -> List[pathlib.Path]:
|
||||
"""Search for all models in a server-local path."""
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||
@ -283,27 +282,31 @@ async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
response_model=bool,
|
||||
)
|
||||
async def sync_to_config() -> bool:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
"""
|
||||
Synchronize model in-memory data structures with disk.
|
||||
|
||||
Call after making changes to models.yaml, autoimport directories
|
||||
or models directory.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
"/merge",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Incompatible models"},
|
||||
404: {"description": "One or more models not found"},
|
||||
409: {"description": "An identical merged model is already installed"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
keys: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(
|
||||
@ -314,28 +317,24 @@ async def merge_models(
|
||||
default=None,
|
||||
),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
"""Merge the indicated diffusers model."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
result: ModelConfigBase = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
response = parse_obj_as(ConvertModelResponse, result.dict())
|
||||
except DuplicateModelException as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except UnknownModelException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{keys}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -14,18 +14,17 @@ from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
DuplicateModelException,
|
||||
MergeInterpolationMethod,
|
||||
ModelConfigBase,
|
||||
ModelInfo,
|
||||
ModelInstallJob,
|
||||
ModelLoader,
|
||||
ModelMerger,
|
||||
ModelLoad,
|
||||
ModelSearch,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
@ -291,7 +290,7 @@ class ModelManagerServiceBase(ABC):
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
_loader: ModelLoader = Field(description="InvokeAIAppConfig object for the current process")
|
||||
_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
|
||||
_event_bus: "EventServiceBase" = Field(description="an event bus to send install events to", default=None)
|
||||
|
||||
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None):
|
||||
@ -304,7 +303,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
self._event_bus = event_bus
|
||||
handlers = [self._event_bus.emit_model_event] if self._event_bus else None
|
||||
self._loader = ModelLoader(config, event_handlers=handlers)
|
||||
self._loader = ModelLoad(config, event_handlers=handlers)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
@ -500,7 +499,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
@ -514,8 +513,12 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self._loader)
|
||||
merger = ModelMerger(self._loader.store)
|
||||
try:
|
||||
if not merged_model_name:
|
||||
merged_model_name = "+".join([self._loader.store.get_model(x).name for x in model_keys])
|
||||
raise Exception("not implemented")
|
||||
|
||||
self.logger.error("ModelMerger needs to be rewritten.")
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_keys=model_keys,
|
||||
|
@ -7,7 +7,7 @@ from .model_manager import ( # noqa F401
|
||||
InvalidModelException,
|
||||
ModelConfigStore,
|
||||
ModelInstall,
|
||||
ModelLoader,
|
||||
ModelLoad,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
|
@ -14,11 +14,10 @@ from .config import ( # noqa F401
|
||||
SubModelType,
|
||||
)
|
||||
from .install import ModelInstall, ModelInstallJob # noqa F401
|
||||
from .loader import ModelInfo, ModelLoader # noqa F401
|
||||
from .loader import ModelInfo, ModelLoad # noqa F401
|
||||
from .lora import ModelPatcher, ONNXModelPatcher
|
||||
from .merge import MergeInterpolationMethod, ModelMerger
|
||||
from .models import OPENAPI_MODEL_CONFIGS, read_checkpoint_meta # noqa F401
|
||||
from .probe import InvalidModelException, ModelProbe # noqa F401
|
||||
from .probe import InvalidModelException, ModelProbeInfo # noqa F401
|
||||
from .search import ModelSearch # noqa F401
|
||||
from .storage import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
|
@ -128,6 +128,12 @@ class DownloadQueueBase(ABC):
|
||||
:param variant: Variant to download, such as "fp16" (repo_ids only).
|
||||
:param event_handlers: Optional callables that will be called whenever job status changes.
|
||||
:returns the job: job.id will be a non-negative value after execution
|
||||
|
||||
Known variants currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -118,7 +118,9 @@ class DownloadQueue(DownloadQueueBase):
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""Create a download job and return its ID."""
|
||||
"""
|
||||
Create a download job and return its ID.
|
||||
"""
|
||||
kwargs = dict()
|
||||
|
||||
if Path(source).exists():
|
||||
@ -503,8 +505,8 @@ class DownloadQueue(DownloadQueueBase):
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
urls_to_download, metadata = self._get_repo_info(repo_id, variant)
|
||||
if job.destination.stem != Path(repo_id).stem:
|
||||
job.destination = job.destination / Path(repo_id).stem
|
||||
if job.destination.name != Path(repo_id).name:
|
||||
job.destination = job.destination / Path(repo_id).name
|
||||
job.metadata = metadata
|
||||
bytes_downloaded = dict()
|
||||
job.total_bytes = 0
|
||||
@ -535,7 +537,15 @@ class DownloadQueue(DownloadQueueBase):
|
||||
repo_id: str,
|
||||
variant: Optional[str] = None,
|
||||
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
|
||||
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
|
||||
"""
|
||||
Given a repo_id and an optional variant, return list of URLs to download to get the model.
|
||||
|
||||
Known variants currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
|
||||
sibs = model_info.siblings
|
||||
paths = [x.rfilename for x in sibs]
|
||||
@ -564,7 +574,19 @@ class DownloadQueue(DownloadQueueBase):
|
||||
basenames = dict()
|
||||
for p in paths:
|
||||
path = Path(p)
|
||||
if path.suffix in [".bin", ".safetensors", ".pt"]:
|
||||
|
||||
if path.suffix == ".onnx":
|
||||
if variant == "onnx":
|
||||
result.add(path)
|
||||
|
||||
elif path.name.startswith("openvino_model"):
|
||||
if variant == "openvino":
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
@ -584,10 +606,13 @@ class DownloadQueue(DownloadQueueBase):
|
||||
basenames[basename] = path
|
||||
else:
|
||||
basenames[basename] = path
|
||||
|
||||
else:
|
||||
result.add(path)
|
||||
continue
|
||||
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
|
||||
return result
|
||||
|
||||
def _download_path(self, job: DownloadJobBase):
|
||||
|
@ -146,13 +146,19 @@ class ModelInstallBase(ABC):
|
||||
"""Return the download queue used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the storage backend used by the installer."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
|
||||
:param overrides: Dict of attributes that will override probed values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
pass
|
||||
@ -201,6 +207,12 @@ class ModelInstallBase(ABC):
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
|
||||
Variants recognized by HuggingFace currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -349,6 +361,11 @@ class ModelInstall(ModelInstallBase):
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
@property
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the storage backend used by the installer."""
|
||||
return self._store
|
||||
|
||||
def register_path(
|
||||
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
|
||||
) -> str: # noqa D102
|
||||
@ -360,7 +377,7 @@ class ModelInstall(ModelInstallBase):
|
||||
key: str = FastModelHash.hash(model_path)
|
||||
registration_data = dict(
|
||||
path=model_path.as_posix(),
|
||||
name=model_path.stem,
|
||||
name=model_path.name if model_path.is_dir() else model_path.stem,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_format=info.format,
|
||||
@ -581,7 +598,7 @@ class ModelInstall(ModelInstallBase):
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
from .loader import ModelInfo, ModelLoader # to avoid circular imports
|
||||
from .loader import ModelInfo, ModelLoad # to avoid circular imports
|
||||
|
||||
new_diffusers_path = None
|
||||
|
||||
@ -594,7 +611,7 @@ class ModelInstall(ModelInstallBase):
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `path`. It doesn't matter
|
||||
# what submodel type we request here, so we get the smallest.
|
||||
loader = ModelLoader(self._config)
|
||||
loader = ModelLoad(self._config)
|
||||
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
|
||||
converted_model: ModelInfo = loader.get_model(key, **submodel)
|
||||
|
||||
|
@ -42,7 +42,7 @@ class ModelInfo:
|
||||
self.context.__exit__(*args, **kwargs)
|
||||
|
||||
|
||||
class ModelLoaderBase(ABC):
|
||||
class ModelLoadBase(ABC):
|
||||
"""Abstract base class for a model loader which works with the ModelConfigStore backend."""
|
||||
|
||||
@abstractmethod
|
||||
@ -113,8 +113,8 @@ class ModelLoaderBase(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class ModelLoader(ModelLoaderBase):
|
||||
"""Implementation of ModelLoaderBase."""
|
||||
class ModelLoad(ModelLoadBase):
|
||||
"""Implementation of ModelLoadBase."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_store: ModelConfigStore
|
||||
@ -130,7 +130,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize ModelLoader object.
|
||||
Initialize ModelLoad object.
|
||||
|
||||
:param config: The app's InvokeAIAppConfig object.
|
||||
"""
|
||||
|
@ -9,14 +9,16 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from . import BaseModelType, ModelConfigBase, ModelLoader, ModelType, ModelVariantType
|
||||
from . import ModelConfigBase, ModelConfigStore, ModelInstall, ModelType
|
||||
from .probe import ModelProbe, ModelProbeInfo
|
||||
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
@ -27,8 +29,18 @@ class MergeInterpolationMethod(str, Enum):
|
||||
|
||||
|
||||
class ModelMerger(object):
|
||||
def __init__(self, manager: ModelLoader):
|
||||
self.manager = manager
|
||||
_store: ModelConfigStore
|
||||
_config: InvokeAIAppConfig
|
||||
|
||||
def __init__(self, store: ModelConfigStore, config: Optional[InvokeAIAppConfig] = None):
|
||||
"""
|
||||
Initialize a ModelMerger object.
|
||||
|
||||
:param store: Underlying storage manager for the running process.
|
||||
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
|
||||
"""
|
||||
self._store = store
|
||||
self._config = config or InvokeAIAppConfig.get_config()
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
@ -70,8 +82,7 @@ class ModelMerger(object):
|
||||
|
||||
def merge_diffusion_models_and_save(
|
||||
self,
|
||||
model_names: List[str],
|
||||
base_model: Union[BaseModelType, str],
|
||||
model_keys: List[str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
@ -93,24 +104,36 @@ class ModelMerger(object):
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = list()
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
model_names = list()
|
||||
config = self._config
|
||||
store = self._store
|
||||
base_models = set()
|
||||
vae = None
|
||||
|
||||
for mod in model_names:
|
||||
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
|
||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||
assert (
|
||||
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||
|
||||
for key in model_keys:
|
||||
info = store.get_model(key)
|
||||
model_names.append(info.name)
|
||||
assert (
|
||||
info["model_format"] == "diffusers"
|
||||
), f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
|
||||
info.model_format == "diffusers"
|
||||
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
||||
assert (
|
||||
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||
info.variant == "normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# pick up the first model's vae
|
||||
if mod == model_names[0]:
|
||||
vae = info.get("vae")
|
||||
model_paths.extend([(config.root_path / info["path"]).as_posix()])
|
||||
if key == model_keys[0]:
|
||||
vae = info.vae
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base_model)
|
||||
model_paths.extend([(config.models_path / info.path).as_posix()])
|
||||
|
||||
assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
|
||||
base_model = base_models.pop()
|
||||
|
||||
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
|
||||
logger.debug(f"interp = {interp}, merge_method={merge_method}")
|
||||
@ -126,18 +149,11 @@ class ModelMerger(object):
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
|
||||
# register model and get its unique key
|
||||
info = ModelProbeInfo(
|
||||
model_type=ModelType.Main,
|
||||
base_type=base_model,
|
||||
format="diffusers",
|
||||
)
|
||||
key = self.manager.installer.register_path(
|
||||
model_path=dump_path,
|
||||
info=info,
|
||||
)
|
||||
installer = ModelInstall(store=self._store, config=self._config)
|
||||
key = installer.register_path(dump_path)
|
||||
|
||||
# update model's config
|
||||
model_config = self.manager.store.get_model(key)
|
||||
model_config = self._store.get_model(key)
|
||||
model_config.update(
|
||||
dict(
|
||||
name=merged_model_name,
|
||||
@ -145,5 +161,5 @@ class ModelMerger(object):
|
||||
vae=vae,
|
||||
)
|
||||
)
|
||||
self.manager.store.update_model(key, model_config)
|
||||
self._store.update_model(key, model_config)
|
||||
return model_config
|
||||
|
@ -15,26 +15,28 @@ from typing import Callable, Optional
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType
|
||||
from .util import SilenceWarnings, lora_token_vector_length, read_checkpoint_meta
|
||||
from .hash import FastModelHash
|
||||
from .util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
|
||||
class InvalidModelException(Exception):
|
||||
"""Raised when an invalid model is encountered."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelProbeInfo(object):
|
||||
class ModelProbeInfo(BaseModel):
|
||||
"""Fields describing a probed model."""
|
||||
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
format: ModelFormat
|
||||
variant_type: ModelVariantType = "normal"
|
||||
prediction_type: SchedulerPredictionType = "v_prediction"
|
||||
upcast_attention: bool = False
|
||||
image_size: int = None
|
||||
hash: str
|
||||
variant_type: Optional[ModelVariantType] = "normal"
|
||||
prediction_type: Optional[SchedulerPredictionType] = "v_prediction"
|
||||
upcast_attention: Optional[bool] = False
|
||||
image_size: Optional[int] = None
|
||||
|
||||
|
||||
class ModelProbeBase(ABC):
|
||||
@ -131,6 +133,7 @@ class ModelProbe(ModelProbeBase):
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
format = probe.get_format()
|
||||
hash = FastModelHash.hash(model)
|
||||
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
@ -142,6 +145,7 @@ class ModelProbe(ModelProbeBase):
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
),
|
||||
format=format,
|
||||
hash=hash,
|
||||
image_size=1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else 768
|
||||
|
@ -3,6 +3,7 @@
|
||||
"""Little command-line utility for probing a model on disk."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@ -25,6 +26,6 @@ args = parser.parse_args()
|
||||
for path in args.model_path:
|
||||
try:
|
||||
info = ModelProbe().probe(path, helper)
|
||||
print(f"{path}: {info}")
|
||||
print(f"{path}:{json.dumps(info.dict(), sort_keys=True, indent=4)}")
|
||||
except InvalidModelException as exc:
|
||||
print(exc)
|
||||
|
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend import BaseModelType, ModelConfigStore, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import ModelLoader
|
||||
from invokeai.backend.model_manager import ModelLoad
|
||||
|
||||
BASIC_MODEL_NAME = "sdxl-base-1-0"
|
||||
VAE_OVERRIDE_MODEL_NAME = "sdxl-base-with-custom-vae-1-0"
|
||||
@ -12,18 +12,18 @@ VAE_NULL_OVERRIDE_MODEL_NAME = "sdxl-base-with-empty-vae-1-0"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager(datadir) -> ModelLoader:
|
||||
def model_manager(datadir) -> ModelLoad:
|
||||
config = InvokeAIAppConfig(root=datadir, conf_path="configs/relative_sub.models.yaml")
|
||||
return ModelLoader(config=config)
|
||||
return ModelLoad(config=config)
|
||||
|
||||
|
||||
def test_get_model_names(model_manager: ModelLoader):
|
||||
def test_get_model_names(model_manager: ModelLoad):
|
||||
store = model_manager.store
|
||||
names = [x.name for x in store.all_models()]
|
||||
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
|
||||
|
||||
|
||||
def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path):
|
||||
def test_get_model_path_for_diffusers(model_manager: ModelLoad, datadir: Path):
|
||||
models = model_manager.store.search_by_name(model_name=BASIC_MODEL_NAME)
|
||||
assert len(models) == 1
|
||||
model_config = models[0]
|
||||
@ -33,7 +33,7 @@ def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path)
|
||||
assert not is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_overridden_vae(model_manager: ModelLoader, datadir: Path):
|
||||
def test_get_model_path_for_overridden_vae(model_manager: ModelLoad, datadir: Path):
|
||||
models = model_manager.store.search_by_name(model_name=VAE_OVERRIDE_MODEL_NAME)
|
||||
assert len(models) == 1
|
||||
model_config = models[0]
|
||||
@ -43,7 +43,7 @@ def test_get_model_path_for_overridden_vae(model_manager: ModelLoader, datadir:
|
||||
assert is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoader, datadir: Path):
|
||||
def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoad, datadir: Path):
|
||||
model_config = model_manager.store.search_by_name(model_name=VAE_NULL_OVERRIDE_MODEL_NAME)[0]
|
||||
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||
assert not is_override
|
||||
|
Reference in New Issue
Block a user