Revert "[MM2] Use typed ModelRecordChanges for model_install() rather than un…"

This reverts commit 633bbb4e85.
This commit is contained in:
psychedelicious 2024-07-24 08:00:09 +10:00 committed by GitHub
parent 633bbb4e85
commit 8d22f5741d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 13536 additions and 256 deletions

View File

@ -6,7 +6,7 @@ import pathlib
import traceback import traceback
from copy import deepcopy from copy import deepcopy
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import List, Optional, Type from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse from fastapi.responses import FileResponse, HTMLResponse
@ -430,11 +430,13 @@ async def delete_model_image(
async def install_model( async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"), source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False), inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
access_token: Optional[str] = Query(description="access token for the remote resource", default=None), # TODO(MM2): Can we type this?
config: ModelRecordChanges = Body( config: Optional[Dict[str, Any]] = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ", description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
example={"name": "string", "description": "string"}, example={"name": "string", "description": "string"},
), ),
access_token: Optional[str] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
"""Install a model using a string identifier. """Install a model using a string identifier.
@ -449,9 +451,8 @@ async def install_model(
- model/name:fp16:path/to/model.safetensors - model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors - model/name::path/to/model.safetensors
`config` is a ModelRecordChanges object. Fields in this object will override `config` is an optional dict containing model configuration values that will override
the ones that are probed automatically. Pass an empty object to accept the ones that are probed automatically.
all the defaults.
`access_token` is an optional access token for use with Urls that require `access_token` is an optional access token for use with Urls that require
authentication. authentication.
@ -736,7 +737,7 @@ async def convert_model(
# write the converted file to the convert path # write the converted file to the convert path
raw_model = converted_model.model raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained") assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path) # type: ignore raw_model.save_pretrained(convert_path)
assert convert_path.exists() assert convert_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict # temporarily rename the original safetensors file so that there is no naming conflict
@ -749,12 +750,12 @@ async def convert_model(
try: try:
new_key = installer.install_path( new_key = installer.install_path(
convert_path, convert_path,
config=ModelRecordChanges( config={
name=original_name, "name": original_name,
description=model_config.description, "description": model_config.description,
hash=model_config.hash, "hash": model_config.hash,
source=model_config.source, "source": model_config.source,
), },
) )
except Exception as e: except Exception as e:
logger.error(str(e)) logger.error(str(e))

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import Any, Dict, List, Optional, Union
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig from invokeai.backend.model_manager import AnyModelConfig
@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
def register_path( def register_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
) -> str: ) -> str:
""" """
Probe and register the model at model_path. Probe and register the model at model_path.
@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
This keeps the model in its current location. This keeps the model in its current location.
:param model_path: Filesystem Path to the model. :param model_path: Filesystem Path to the model.
:param config: ModelRecordChanges object that will override autoassigned model record values. :param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model. :returns id: The string ID of the registered model.
""" """
@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
def install_path( def install_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
) -> str: ) -> str:
""" """
Probe, register and install the model in the models directory. Probe, register and install the model in the models directory.
@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
the models directory handled by InvokeAI. the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model. :param model_path: Filesystem Path to the model.
:param config: ModelRecordChanges object that will override autoassigned model record values. :param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model. :returns id: The string ID of the registered model.
""" """
@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
def heuristic_import( def heuristic_import(
self, self,
source: str, source: str,
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
inplace: Optional[bool] = False, inplace: Optional[bool] = False,
) -> ModelInstallJob: ) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions. r"""Install the indicated model using heuristics to interpret user intentions.
:param source: String source :param source: String source
:param config: Optional ModelRecordChanges object. Any fields in this object :param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the will override corresponding autoassigned probe fields in the
model's config record as described in `import_model()`. model's config record as described in `import_model()`.
:param access_token: Optional access token for remote sources. :param access_token: Optional access token for remote sources.
@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
def import_model( def import_model(
self, self,
source: ModelSource, source: ModelSource,
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
"""Install the indicated model. """Install the indicated model.

View File

@ -2,14 +2,13 @@ import re
import traceback import traceback
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Set, Union from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@ -134,9 +133,8 @@ class ModelInstallJob(BaseModel):
id: int = Field(description="Unique ID for this job") id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: ModelRecordChanges = Field( config_in: Dict[str, Any] = Field(
default_factory=ModelRecordChanges, default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
description="Configuration information (e.g. 'description') to apply to model.",
) )
config_out: Optional[AnyModelConfig] = Field( config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object." default=None, description="After successful installation, this will hold the configuration object."

View File

@ -163,27 +163,26 @@ class ModelInstallService(ModelInstallServiceBase):
def register_path( def register_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
config = config or ModelRecordChanges() config = config or {}
if not config.source: if not config.get("source"):
config.source = model_path.resolve().as_posix() config["source"] = model_path.resolve().as_posix()
config.source_type = ModelSourceType.Path config["source_type"] = ModelSourceType.Path
return self._register(model_path, config) return self._register(model_path, config)
def install_path( def install_path(
self, self,
model_path: Union[Path, str], model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102 ) -> str: # noqa D102
model_path = Path(model_path) model_path = Path(model_path)
config = config or ModelRecordChanges() config = config or {}
info: AnyModelConfig = ModelProbe.probe(
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
) # type: ignore
if preferred_name := config.name: info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix) preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = ( dest_path = (
@ -205,7 +204,7 @@ class ModelInstallService(ModelInstallServiceBase):
def heuristic_import( def heuristic_import(
self, self,
source: str, source: str,
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
inplace: Optional[bool] = False, inplace: Optional[bool] = False,
) -> ModelInstallJob: ) -> ModelInstallJob:
@ -217,7 +216,7 @@ class ModelInstallService(ModelInstallServiceBase):
source_obj.access_token = access_token source_obj.access_token = access_token
return self.import_model(source_obj, config) return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102 def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs: if similar_jobs:
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.") self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
@ -319,17 +318,16 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = self._app_config.models_path / model_path model_path = self._app_config.models_path / model_path
model_path = model_path.resolve() model_path = model_path.resolve()
config = ModelRecordChanges( config: dict[str, Any] = {}
name=model_name, config["name"] = model_name
description=stanza.get("description"), config["description"] = stanza.get("description")
)
legacy_config_path = stanza.get("config") legacy_config_path = stanza.get("config")
if legacy_config_path: if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir. # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path = self._app_config.root_path / legacy_config_path legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path): if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path) legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config.config_path = str(legacy_config_path) config["config_path"] = str(legacy_config_path)
try: try:
id = self.register_path(model_path=model_path, config=config) id = self.register_path(model_path=model_path, config=config)
self._logger.info(f"Migrated {model_name} with id {id}") self._logger.info(f"Migrated {model_name} with id {id}")
@ -502,11 +500,11 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path) job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes job.bytes = job.total_bytes
self._signal_job_running(job) self._signal_job_running(job)
job.config_in.source = str(job.source) job.config_in["source"] = str(job.source)
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__] job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any # enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)): if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in.source_api_response = job.source_metadata.api_response job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace: if job.inplace:
key = self.register_path(job.local_path, job.config_in) key = self.register_path(job.local_path, job.config_in)
@ -641,11 +639,11 @@ class ModelInstallService(ModelInstallServiceBase):
return new_path return new_path
def _register( def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str: ) -> str:
config = config or ModelRecordChanges() config = config or {}
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
model_path = model_path.resolve() model_path = model_path.resolve()
@ -676,13 +674,11 @@ class ModelInstallService(ModelInstallServiceBase):
precision = TorchDevice.choose_torch_dtype() precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model( def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
) -> ModelInstallJob:
return ModelInstallJob( return ModelInstallJob(
id=self._next_id(), id=self._next_id(),
source=source, source=source,
config_in=config or ModelRecordChanges(), config_in=config or {},
local_path=Path(source.path), local_path=Path(source.path),
inplace=source.inplace or False, inplace=source.inplace or False,
) )
@ -690,7 +686,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_hf( def _import_from_hf(
self, self,
source: HFModelSource, source: HFModelSource,
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests # Add user's cached access token to HuggingFace requests
if source.access_token is None: if source.access_token is None:
@ -706,7 +702,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url( def _import_from_url(
self, self,
source: URLModelSource, source: URLModelSource,
config: Optional[ModelRecordChanges] = None, config: Optional[Dict[str, Any]],
) -> ModelInstallJob: ) -> ModelInstallJob:
remote_files, metadata = self._remote_files_from_source(source) remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model( return self._import_remote_model(
@ -721,7 +717,7 @@ class ModelInstallService(ModelInstallServiceBase):
source: HFModelSource | URLModelSource, source: HFModelSource | URLModelSource,
remote_files: List[RemoteModelFile], remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata], metadata: Optional[AnyModelRepoMetadata],
config: Optional[ModelRecordChanges], config: Optional[Dict[str, Any]],
) -> ModelInstallJob: ) -> ModelInstallJob:
if len(remote_files) == 0: if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found") raise ValueError(f"{source}: No downloadable files found")
@ -734,7 +730,7 @@ class ModelInstallService(ModelInstallServiceBase):
install_job = ModelInstallJob( install_job = ModelInstallJob(
id=self._next_id(), id=self._next_id(),
source=source, source=source,
config_in=config or ModelRecordChanges(), config_in=config or {},
source_metadata=metadata, source_metadata=metadata,
local_path=destdir, # local path may change once the download has started due to content-disposition handling local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0, bytes=0,

View File

@ -18,7 +18,6 @@ from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings, ControlAdapterDefaultSettings,
MainModelDefaultSettings, MainModelDefaultSettings,
ModelFormat, ModelFormat,
ModelSourceType,
ModelType, ModelType,
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
@ -67,16 +66,10 @@ class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model.""" """A set of changes to apply to a model."""
# Changes applicable to all models # Changes applicable to all models
source: Optional[str] = Field(description="original source of the model", default=None)
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
name: Optional[str] = Field(description="Name of the model.", default=None) name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None) path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None) description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None) base: Optional[BaseModelType] = Field(description="The base model.", default=None)
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field( default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None description="Default settings for this model", default=None

View File

@ -354,7 +354,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision.""" """Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers]
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:
@ -365,7 +365,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
"""Model config for T2I.""" """Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers]
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:

View File

@ -155,8 +155,5 @@
"vite-plugin-eslint": "^1.8.1", "vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.3.2", "vite-tsconfig-paths": "^4.3.2",
"vitest": "^1.6.0" "vitest": "^1.6.0"
},
"engines": {
"pnpm": "8"
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { type InstallModelArg, useInstallModelMutation } from 'services/api/endpoints/models'; import { useInstallModelMutation } from 'services/api/endpoints/models';
type InstallModelArgWithCallbacks = InstallModelArg & { type InstallModelArg = {
source: string;
inplace?: boolean;
onSuccess?: () => void; onSuccess?: () => void;
onError?: (error: unknown) => void; onError?: (error: unknown) => void;
}; };
@ -13,9 +15,8 @@ export const useInstallModel = () => {
const [_installModel, request] = useInstallModelMutation(); const [_installModel, request] = useInstallModelMutation();
const installModel = useCallback( const installModel = useCallback(
({ source, inplace, config, onSuccess, onError }: InstallModelArgWithCallbacks) => { ({ source, inplace, onSuccess, onError }: InstallModelArg) => {
config ||= {}; _installModel({ source, inplace })
_installModel({ source, inplace, config })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
if (onSuccess) { if (onSuccess) {

View File

@ -12,19 +12,17 @@ type Props = {
export const StarterModelsResultItem = ({ result }: Props) => { export const StarterModelsResultItem = ({ result }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();
const allSources = useMemo(() => { const allSources = useMemo(() => {
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }]; const _allSources = [result.source];
if (result.dependencies) { if (result.dependencies) {
for (const d of result.dependencies) { _allSources.push(...result.dependencies.map((d) => d.source));
_allSources.push({ source: d.source, config: { name: d.name, description: d.description } });
}
} }
return _allSources; return _allSources;
}, [result]); }, [result]);
const [installModel] = useInstallModel(); const [installModel] = useInstallModel();
const onClick = useCallback(() => { const onClick = useCallback(() => {
for (const { config, source } of allSources) { for (const source of allSources) {
installModel({ config, source }); installModel({ source });
} }
}, [allSources, installModel]); }, [allSources, installModel]);

View File

@ -39,10 +39,9 @@ type DeleteModelImageResponse = void;
type ConvertMainModelResponse = type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json']; paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
export type InstallModelArg = { type InstallModelArg = {
source: paths['/api/v2/models/install']['post']['parameters']['query']['source']; source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
inplace?: paths['/api/v2/models/install']['post']['parameters']['query']['inplace']; inplace?: paths['/api/v2/models/install']['post']['parameters']['query']['inplace'];
config?: paths['/api/v2/models/install']['post']['requestBody']['content']['application/json'];
}; };
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json']; type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
@ -125,12 +124,11 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [{ type: 'ModelConfig', id: LIST_TAG }], invalidatesTags: [{ type: 'ModelConfig', id: LIST_TAG }],
}), }),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({ installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, inplace = true, config }) => { query: ({ source, inplace = true }) => {
return { return {
url: buildModelsUrl('install'), url: buildModelsUrl('install'),
params: { source, inplace }, params: { source, inplace },
method: 'POST', method: 'POST',
body: config,
}; };
}, },
invalidatesTags: ['ModelInstalls'], invalidatesTags: ['ModelInstalls'],

View File

@ -103,9 +103,8 @@ export type paths = {
* - model/name:fp16:path/to/model.safetensors * - model/name:fp16:path/to/model.safetensors
* - model/name::path/to/model.safetensors * - model/name::path/to/model.safetensors
* *
* `config` is a ModelRecordChanges object. Fields in this object will override * `config` is an optional dict containing model configuration values that will override
* the ones that are probed automatically. Pass an empty object to accept * the ones that are probed automatically.
* all the defaults.
* *
* `access_token` is an optional access token for use with Urls that require * `access_token` is an optional access token for use with Urls that require
* authentication. * authentication.
@ -1590,7 +1589,6 @@ export type components = {
cover_image?: string | null; cover_image?: string | null;
/** /**
* Format * Format
* @default diffusers
* @constant * @constant
* @enum {string} * @enum {string}
*/ */
@ -3173,7 +3171,7 @@ export type components = {
/** /**
* Fp32 * Fp32
* @description Whether or not to use full float32 precision * @description Whether or not to use full float32 precision
* @default true * @default false
*/ */
fp32?: boolean; fp32?: boolean;
/** /**
@ -3256,7 +3254,7 @@ export type components = {
/** /**
* Fp32 * Fp32
* @description Whether or not to use full float32 precision * @description Whether or not to use full float32 precision
* @default true * @default false
*/ */
fp32?: boolean; fp32?: boolean;
/** /**
@ -6577,7 +6575,7 @@ export type components = {
/** /**
* Fp32 * Fp32
* @description Whether or not to use full float32 precision * @description Whether or not to use full float32 precision
* @default true * @default false
*/ */
fp32?: boolean; fp32?: boolean;
/** /**
@ -7306,146 +7304,146 @@ export type components = {
project_id: string | null; project_id: string | null;
}; };
InvocationOutputMap: { InvocationOutputMap: {
integer: components["schemas"]["IntegerOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
color: components["schemas"]["ColorOutput"];
merge_tiles_to_image: components["schemas"]["ImageOutput"];
merge_metadata: components["schemas"]["MetadataOutput"]; merge_metadata: components["schemas"]["MetadataOutput"];
denoise_latents: components["schemas"]["LatentsOutput"]; string_join_three: components["schemas"]["StringOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
img_resize: components["schemas"]["ImageOutput"];
float_to_int: components["schemas"]["IntegerOutput"];
img_scale: components["schemas"]["ImageOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
compel: components["schemas"]["ConditioningOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
string_join: components["schemas"]["StringOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
lblend: components["schemas"]["LatentsOutput"]; lblend: components["schemas"]["LatentsOutput"];
lscale: components["schemas"]["LatentsOutput"]; img_channel_multiply: components["schemas"]["ImageOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
float_collection: components["schemas"]["FloatCollectionOutput"]; float_collection: components["schemas"]["FloatCollectionOutput"];
range: components["schemas"]["IntegerCollectionOutput"]; face_off: components["schemas"]["FaceOffOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
mlsd_image_processor: components["schemas"]["ImageOutput"]; mlsd_image_processor: components["schemas"]["ImageOutput"];
string_replace: components["schemas"]["StringOutput"]; color_map_image_processor: components["schemas"]["ImageOutput"];
image_mask_to_tensor: components["schemas"]["MaskOutput"]; img_paste: components["schemas"]["ImageOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"]; img_scale: components["schemas"]["ImageOutput"];
infill_lama: components["schemas"]["ImageOutput"];
metadata_item: components["schemas"]["MetadataItemOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
latents_collection: components["schemas"]["LatentsCollectionOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
rand_float: components["schemas"]["FloatOutput"];
noise: components["schemas"]["NoiseOutput"];
face_mask_detection: components["schemas"]["FaceMaskOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
zoe_depth_image_processor: components["schemas"]["ImageOutput"]; zoe_depth_image_processor: components["schemas"]["ImageOutput"];
tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
save_image: components["schemas"]["ImageOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
face_identifier: components["schemas"]["ImageOutput"];
img_blur: components["schemas"]["ImageOutput"]; img_blur: components["schemas"]["ImageOutput"];
img_paste: components["schemas"]["ImageOutput"]; merge_tiles_to_image: components["schemas"]["ImageOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"]; latents_collection: components["schemas"]["LatentsCollectionOutput"];
add: components["schemas"]["IntegerOutput"];
tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"];
img_crop: components["schemas"]["ImageOutput"];
conditioning: components["schemas"]["ConditioningOutput"];
esrgan: components["schemas"]["ImageOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
mul: components["schemas"]["IntegerOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
img_mul: components["schemas"]["ImageOutput"]; img_mul: components["schemas"]["ImageOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; image_mask_to_tensor: components["schemas"]["MaskOutput"];
infill_tile: components["schemas"]["ImageOutput"];
float: components["schemas"]["FloatOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
tomask: components["schemas"]["ImageOutput"];
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
integer_math: components["schemas"]["IntegerOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
sub: components["schemas"]["IntegerOutput"];
crop_latents: components["schemas"]["LatentsOutput"];
string_join_three: components["schemas"]["StringOutput"];
cv_inpaint: components["schemas"]["ImageOutput"]; cv_inpaint: components["schemas"]["ImageOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
mask_combine: components["schemas"]["ImageOutput"];
img_pad_crop: components["schemas"]["ImageOutput"]; img_pad_crop: components["schemas"]["ImageOutput"];
freeu: components["schemas"]["UNetOutput"];
lresize: components["schemas"]["LatentsOutput"]; lresize: components["schemas"]["LatentsOutput"];
metadata: components["schemas"]["MetadataOutput"]; conditioning: components["schemas"]["ConditioningOutput"];
color_map_image_processor: components["schemas"]["ImageOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
l2i: components["schemas"]["ImageOutput"];
show_image: components["schemas"]["ImageOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
round_float: components["schemas"]["FloatOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
spandrel_image_to_image: components["schemas"]["ImageOutput"];
step_param_easing: components["schemas"]["FloatCollectionOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
color_correct: components["schemas"]["ImageOutput"];
float_range: components["schemas"]["FloatCollectionOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
img_conv: components["schemas"]["ImageOutput"];
seamless: components["schemas"]["SeamlessModeOutput"];
ip_adapter: components["schemas"]["IPAdapterOutput"];
i2l: components["schemas"]["LatentsOutput"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
vae_loader: components["schemas"]["VAEOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
blank_image: components["schemas"]["ImageOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
dynamic_prompt: components["schemas"]["StringCollectionOutput"]; dynamic_prompt: components["schemas"]["StringCollectionOutput"];
mask_edge: components["schemas"]["ImageOutput"]; tomask: components["schemas"]["ImageOutput"];
img_channel_multiply: components["schemas"]["ImageOutput"]; mul: components["schemas"]["IntegerOutput"];
controlnet: components["schemas"]["ControlOutput"]; seamless: components["schemas"]["SeamlessModeOutput"];
latents: components["schemas"]["LatentsOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
canny_image_processor: components["schemas"]["ImageOutput"]; canny_image_processor: components["schemas"]["ImageOutput"];
core_metadata: components["schemas"]["MetadataOutput"]; metadata_item: components["schemas"]["MetadataItemOutput"];
div: components["schemas"]["IntegerOutput"]; add: components["schemas"]["IntegerOutput"];
crop_latents: components["schemas"]["LatentsOutput"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
string_split: components["schemas"]["String2Output"]; string_split: components["schemas"]["String2Output"];
img_chan: components["schemas"]["ImageOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
rand_int: components["schemas"]["IntegerOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
face_off: components["schemas"]["FaceOffOutput"];
boolean: components["schemas"]["BooleanOutput"];
string: components["schemas"]["StringOutput"];
float_math: components["schemas"]["FloatOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
img_watermark: components["schemas"]["ImageOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
img_lerp: components["schemas"]["ImageOutput"];
image: components["schemas"]["ImageOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
tile_image_processor: components["schemas"]["ImageOutput"]; tile_image_processor: components["schemas"]["ImageOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
save_image: components["schemas"]["ImageOutput"];
controlnet: components["schemas"]["ControlOutput"];
float_math: components["schemas"]["FloatOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
i2l: components["schemas"]["LatentsOutput"];
infill_lama: components["schemas"]["ImageOutput"];
sub: components["schemas"]["IntegerOutput"];
div: components["schemas"]["IntegerOutput"];
face_mask_detection: components["schemas"]["FaceMaskOutput"];
esrgan: components["schemas"]["ImageOutput"];
mask_combine: components["schemas"]["ImageOutput"];
ip_adapter: components["schemas"]["IPAdapterOutput"];
blank_image: components["schemas"]["ImageOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
rand_int: components["schemas"]["IntegerOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
face_identifier: components["schemas"]["ImageOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
compel: components["schemas"]["ConditioningOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
freeu: components["schemas"]["UNetOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
img_conv: components["schemas"]["ImageOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
color_correct: components["schemas"]["ImageOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
noise: components["schemas"]["NoiseOutput"];
float_range: components["schemas"]["FloatCollectionOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
float_to_int: components["schemas"]["IntegerOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
latents: components["schemas"]["LatentsOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
mask_edge: components["schemas"]["ImageOutput"];
metadata: components["schemas"]["MetadataOutput"];
string_join: components["schemas"]["StringOutput"];
core_metadata: components["schemas"]["MetadataOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
lscale: components["schemas"]["LatentsOutput"];
string: components["schemas"]["StringOutput"];
integer: components["schemas"]["IntegerOutput"];
string_replace: components["schemas"]["StringOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
image: components["schemas"]["ImageOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
step_param_easing: components["schemas"]["FloatCollectionOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
img_chan: components["schemas"]["ImageOutput"];
float: components["schemas"]["FloatOutput"];
boolean_collection: components["schemas"]["BooleanCollectionOutput"]; boolean_collection: components["schemas"]["BooleanCollectionOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
boolean: components["schemas"]["BooleanOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
denoise_latents: components["schemas"]["LatentsOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
color: components["schemas"]["ColorOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
rand_float: components["schemas"]["FloatOutput"];
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
range: components["schemas"]["IntegerCollectionOutput"];
img_watermark: components["schemas"]["ImageOutput"];
spandrel_image_to_image: components["schemas"]["ImageOutput"];
show_image: components["schemas"]["ImageOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
infill_tile: components["schemas"]["ImageOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
img_lerp: components["schemas"]["ImageOutput"];
l2i: components["schemas"]["ImageOutput"];
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
vae_loader: components["schemas"]["VAEOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
integer_math: components["schemas"]["IntegerOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
img_crop: components["schemas"]["ImageOutput"];
img_resize: components["schemas"]["ImageOutput"];
round_float: components["schemas"]["FloatOutput"];
}; };
/** /**
* InvocationStartedEvent * InvocationStartedEvent
@ -7792,7 +7790,7 @@ export type components = {
/** /**
* Fp32 * Fp32
* @description Whether or not to use full float32 precision * @description Whether or not to use full float32 precision
* @default true * @default false
*/ */
fp32?: boolean; fp32?: boolean;
/** /**
@ -9596,8 +9594,11 @@ export type components = {
* @description Information about why the job failed * @description Information about why the job failed
*/ */
error_reason?: string | null; error_reason?: string | null;
/** @description Configuration information (e.g. 'description') to apply to model. */ /**
config_in?: components["schemas"]["ModelRecordChanges"]; * Config In
* @description Configuration information (e.g. 'description') to apply to model.
*/
config_in?: Record<string, never>;
/** /**
* Config Out * Config Out
* @description After successful installation, this will hold the configuration object. * @description After successful installation, this will hold the configuration object.
@ -9749,18 +9750,6 @@ export type components = {
* @description A set of changes to apply to a model. * @description A set of changes to apply to a model.
*/ */
ModelRecordChanges: { ModelRecordChanges: {
/**
* Source
* @description original source of the model
*/
source?: string | null;
/** @description type of model source */
source_type?: components["schemas"]["ModelSourceType"] | null;
/**
* Source Api Response
* @description metadata from remote source
*/
source_api_response?: string | null;
/** /**
* Name * Name
* @description Name of the model. * @description Name of the model.
@ -9778,18 +9767,6 @@ export type components = {
description?: string | null; description?: string | null;
/** @description The base model. */ /** @description The base model. */
base?: components["schemas"]["BaseModelType"] | null; base?: components["schemas"]["BaseModelType"] | null;
/** @description Type of model */
type?: components["schemas"]["ModelType"] | null;
/**
* Key
* @description Database ID for this model
*/
key?: string | null;
/**
* Hash
* @description hash of model file
*/
hash?: string | null;
/** /**
* Trigger Phrases * Trigger Phrases
* @description Set of trigger phrases for this model * @description Set of trigger phrases for this model
@ -12633,7 +12610,6 @@ export type components = {
cover_image?: string | null; cover_image?: string | null;
/** /**
* Format * Format
* @default diffusers
* @constant * @constant
* @enum {string} * @enum {string}
*/ */
@ -14327,9 +14303,8 @@ export type operations = {
* - model/name:fp16:path/to/model.safetensors * - model/name:fp16:path/to/model.safetensors
* - model/name::path/to/model.safetensors * - model/name::path/to/model.safetensors
* *
* `config` is a ModelRecordChanges object. Fields in this object will override * `config` is an optional dict containing model configuration values that will override
* the ones that are probed automatically. Pass an empty object to accept * the ones that are probed automatically.
* all the defaults.
* *
* `access_token` is an optional access token for use with Urls that require * `access_token` is an optional access token for use with Urls that require
* authentication. * authentication.
@ -14348,11 +14323,10 @@ export type operations = {
source: string; source: string;
/** @description Whether or not to install a local model in place */ /** @description Whether or not to install a local model in place */
inplace?: boolean | null; inplace?: boolean | null;
/** @description access token for the remote resource */
access_token?: string | null; access_token?: string | null;
}; };
}; };
requestBody: { requestBody?: {
content: { content: {
/** /**
* @example { * @example {
@ -14360,7 +14334,7 @@ export type operations = {
* "description": "string" * "description": "string"
* } * }
*/ */
"application/json": components["schemas"]["ModelRecordChanges"]; "application/json": Record<string, never> | null;
}; };
}; };
responses: { responses: {

View File

@ -72,16 +72,14 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None key = None
with pytest.raises((ValidationError, InvalidModelConfigException)): with pytest.raises((ValidationError, InvalidModelConfigException)):
key = mm2_installer.register_path( key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora"))
)
assert key is None assert key is None
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store store = mm2_installer.record_store
key = mm2_installer.register_path( key = mm2_installer.register_path(
embedding_file, ModelRecordChanges(name="banana_sushi", source="fake/repo_id", key="xyzzy") embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "key": "xyzzy"}
) )
model_record = store.get_model(key) model_record = store.get_model(key)
assert model_record.name == "banana_sushi" assert model_record.name == "banana_sushi"
@ -133,7 +131,7 @@ def test_background_install(
path: Path = request.getfixturevalue(fixture_name) path: Path = request.getfixturevalue(fixture_name)
description = "Test of metadata assignment" description = "Test of metadata assignment"
source = LocalModelSource(path=path, inplace=False) source = LocalModelSource(path=path, inplace=False)
job = mm2_installer.import_model(source, config=ModelRecordChanges(description=description)) job = mm2_installer.import_model(source, config={"description": description})
assert job is not None assert job is not None
assert isinstance(job, ModelInstallJob) assert isinstance(job, ModelInstallJob)