[MM2] Use typed ModelRecordChanges for model_install() rather than untyped dict (#6645)

* [MM2] replace untyped config dict passed to install_model with typed ModelRecordChanges

- adjusted frontend to work with new schema
- used this facility to assign "starter model" names and descriptions to the installed
  models.

* documentation fix

* [MM2] replace untyped config dict passed to install_model with typed ModelRecordChanges

- adjusted frontend to work with new schema
- used this facility to assign "starter model" names and descriptions to the installed
  models.

* documentation fix

* remove v9 pnpm lockfile

* [MM2] replace untyped config dict passed to install_model with typed ModelRecordChanges

- adjusted frontend to work with new schema
- used this facility to assign "starter model" names and descriptions to the installed
  models.

* [MM2] replace untyped config dict passed to install_model with typed ModelRecordChanges

- adjusted frontend to work with new schema
- used this facility to assign "starter model" names and descriptions to the installed
  models.

* remove v9 pnpm lockfile

* regenerate schema.ts

* prettified

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-07-23 17:41:00 -04:00 committed by GitHub
parent a221ab2fb6
commit 633bbb4e85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 266 additions and 13546 deletions

View File

@ -6,7 +6,7 @@ import pathlib
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type
from typing import List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
@ -430,13 +430,11 @@ async def delete_model_image(
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
example={"name": "string", "description": "string"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
@ -451,8 +449,9 @@ async def install_model(
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`config` is a ModelRecordChanges object. Fields in this object will override
the ones that are probed automatically. Pass an empty object to accept
all the defaults.
`access_token` is an optional access token for use with Urls that require
authentication.
@ -737,7 +736,7 @@ async def convert_model(
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path)
raw_model.save_pretrained(convert_path) # type: ignore
assert convert_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
@ -750,12 +749,12 @@ async def convert_model(
try:
new_key = installer.install_path(
convert_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
config=ModelRecordChanges(
name=original_name,
description=model_config.description,
hash=model_config.hash,
source=model_config.source,
),
)
except Exception as e:
logger.error(str(e))

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional, Union
from pydantic.networks import AnyHttpUrl
@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str:
"""
Probe and register the model at model_path.
@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:param config: ModelRecordChanges object that will override autoassigned model record values.
:returns id: The string ID of the registered model.
"""
@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str:
"""
Probe, register and install the model in the models directory.
@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:param config: ModelRecordChanges object that will override autoassigned model record values.
:returns id: The string ID of the registered model.
"""
@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
:param source: String source
:param config: Optional dict. Any fields in this dict
:param config: Optional ModelRecordChanges object. Any fields in this object
will override corresponding autoassigned probe fields in the
model's config record as described in `import_model()`.
:param access_token: Optional access token for remote sources.
@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> ModelInstallJob:
"""Install the indicated model.

View File

@ -2,13 +2,14 @@ import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from typing import Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@ -133,8 +134,9 @@ class ModelInstallJob(BaseModel):
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
config_in: ModelRecordChanges = Field(
default_factory=ModelRecordChanges,
description="Configuration information (e.g. 'description') to apply to model.",
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."

View File

@ -163,26 +163,27 @@ class ModelInstallService(ModelInstallServiceBase):
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
config = config or ModelRecordChanges()
if not config.source:
config.source = model_path.resolve().as_posix()
config.source_type = ModelSourceType.Path
return self._register(model_path, config)
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
config = config or ModelRecordChanges()
info: AnyModelConfig = ModelProbe.probe(
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
) # type: ignore
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
if preferred_name := config.get("name"):
if preferred_name := config.name:
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = (
@ -204,7 +205,7 @@ class ModelInstallService(ModelInstallServiceBase):
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
@ -216,7 +217,7 @@ class ModelInstallService(ModelInstallServiceBase):
source_obj.access_token = access_token
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
@ -318,16 +319,17 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = self._app_config.models_path / model_path
model_path = model_path.resolve()
config: dict[str, Any] = {}
config["name"] = model_name
config["description"] = stanza.get("description")
config = ModelRecordChanges(
name=model_name,
description=stanza.get("description"),
)
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
config.config_path = str(legacy_config_path)
try:
id = self.register_path(model_path=model_path, config=config)
self._logger.info(f"Migrated {model_name} with id {id}")
@ -500,11 +502,11 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
job.config_in.source = str(job.source)
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
job.config_in.source_api_response = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
@ -639,11 +641,11 @@ class ModelInstallService(ModelInstallServiceBase):
return new_path
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
) -> str:
config = config or {}
config = config or ModelRecordChanges()
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
model_path = model_path.resolve()
@ -674,11 +676,13 @@ class ModelInstallService(ModelInstallServiceBase):
precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
def _import_local_model(
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
) -> ModelInstallJob:
return ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or {},
config_in=config or ModelRecordChanges(),
local_path=Path(source.path),
inplace=source.inplace or False,
)
@ -686,7 +690,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_hf(
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
if source.access_token is None:
@ -702,7 +706,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(
self,
source: URLModelSource,
config: Optional[Dict[str, Any]],
config: Optional[ModelRecordChanges] = None,
) -> ModelInstallJob:
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
@ -717,7 +721,7 @@ class ModelInstallService(ModelInstallServiceBase):
source: HFModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
config: Optional[ModelRecordChanges],
) -> ModelInstallJob:
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
@ -730,7 +734,7 @@ class ModelInstallService(ModelInstallServiceBase):
install_job = ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or {},
config_in=config or ModelRecordChanges(),
source_metadata=metadata,
local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0,

View File

@ -18,6 +18,7 @@ from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
@ -66,10 +67,16 @@ class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model."""
# Changes applicable to all models
source: Optional[str] = Field(description="original source of the model", default=None)
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None

View File

@ -354,7 +354,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
@ -365,7 +365,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:

View File

@ -155,5 +155,8 @@
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.3.2",
"vitest": "^1.6.0"
},
"engines": {
"pnpm": "8"
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,9 @@
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import { type InstallModelArg, useInstallModelMutation } from 'services/api/endpoints/models';
type InstallModelArg = {
source: string;
inplace?: boolean;
type InstallModelArgWithCallbacks = InstallModelArg & {
onSuccess?: () => void;
onError?: (error: unknown) => void;
};
@ -15,8 +13,9 @@ export const useInstallModel = () => {
const [_installModel, request] = useInstallModelMutation();
const installModel = useCallback(
({ source, inplace, onSuccess, onError }: InstallModelArg) => {
_installModel({ source, inplace })
({ source, inplace, config, onSuccess, onError }: InstallModelArgWithCallbacks) => {
config ||= {};
_installModel({ source, inplace, config })
.unwrap()
.then((_) => {
if (onSuccess) {

View File

@ -12,17 +12,19 @@ type Props = {
export const StarterModelsResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const allSources = useMemo(() => {
const _allSources = [result.source];
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }];
if (result.dependencies) {
_allSources.push(...result.dependencies.map((d) => d.source));
for (const d of result.dependencies) {
_allSources.push({ source: d.source, config: { name: d.name, description: d.description } });
}
}
return _allSources;
}, [result]);
const [installModel] = useInstallModel();
const onClick = useCallback(() => {
for (const source of allSources) {
installModel({ source });
for (const { config, source } of allSources) {
installModel({ config, source });
}
}, [allSources, installModel]);

View File

@ -39,9 +39,10 @@ type DeleteModelImageResponse = void;
type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type InstallModelArg = {
export type InstallModelArg = {
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
inplace?: paths['/api/v2/models/install']['post']['parameters']['query']['inplace'];
config?: paths['/api/v2/models/install']['post']['requestBody']['content']['application/json'];
};
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
@ -124,11 +125,12 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [{ type: 'ModelConfig', id: LIST_TAG }],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, inplace = true }) => {
query: ({ source, inplace = true, config }) => {
return {
url: buildModelsUrl('install'),
params: { source, inplace },
method: 'POST',
body: config,
};
},
invalidatesTags: ['ModelInstalls'],

View File

@ -103,8 +103,9 @@ export type paths = {
* - model/name:fp16:path/to/model.safetensors
* - model/name::path/to/model.safetensors
*
* `config` is an optional dict containing model configuration values that will override
* the ones that are probed automatically.
* `config` is a ModelRecordChanges object. Fields in this object will override
* the ones that are probed automatically. Pass an empty object to accept
* all the defaults.
*
* `access_token` is an optional access token for use with Urls that require
* authentication.
@ -1589,6 +1590,7 @@ export type components = {
cover_image?: string | null;
/**
* Format
* @default diffusers
* @constant
* @enum {string}
*/
@ -3171,7 +3173,7 @@ export type components = {
/**
* Fp32
* @description Whether or not to use full float32 precision
* @default false
* @default true
*/
fp32?: boolean;
/**
@ -3254,7 +3256,7 @@ export type components = {
/**
* Fp32
* @description Whether or not to use full float32 precision
* @default false
* @default true
*/
fp32?: boolean;
/**
@ -6575,7 +6577,7 @@ export type components = {
/**
* Fp32
* @description Whether or not to use full float32 precision
* @default false
* @default true
*/
fp32?: boolean;
/**
@ -7304,146 +7306,146 @@ export type components = {
project_id: string | null;
};
InvocationOutputMap: {
integer: components["schemas"]["IntegerOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
color: components["schemas"]["ColorOutput"];
merge_tiles_to_image: components["schemas"]["ImageOutput"];
merge_metadata: components["schemas"]["MetadataOutput"];
string_join_three: components["schemas"]["StringOutput"];
lblend: components["schemas"]["LatentsOutput"];
img_channel_multiply: components["schemas"]["ImageOutput"];
float_collection: components["schemas"]["FloatCollectionOutput"];
face_off: components["schemas"]["FaceOffOutput"];
mlsd_image_processor: components["schemas"]["ImageOutput"];
color_map_image_processor: components["schemas"]["ImageOutput"];
img_paste: components["schemas"]["ImageOutput"];
denoise_latents: components["schemas"]["LatentsOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
img_resize: components["schemas"]["ImageOutput"];
float_to_int: components["schemas"]["IntegerOutput"];
img_scale: components["schemas"]["ImageOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
compel: components["schemas"]["ConditioningOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
string_join: components["schemas"]["StringOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
lblend: components["schemas"]["LatentsOutput"];
lscale: components["schemas"]["LatentsOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
float_collection: components["schemas"]["FloatCollectionOutput"];
range: components["schemas"]["IntegerCollectionOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
mlsd_image_processor: components["schemas"]["ImageOutput"];
string_replace: components["schemas"]["StringOutput"];
image_mask_to_tensor: components["schemas"]["MaskOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"];
infill_lama: components["schemas"]["ImageOutput"];
metadata_item: components["schemas"]["MetadataItemOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
latents_collection: components["schemas"]["LatentsCollectionOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
rand_float: components["schemas"]["FloatOutput"];
noise: components["schemas"]["NoiseOutput"];
face_mask_detection: components["schemas"]["FaceMaskOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
zoe_depth_image_processor: components["schemas"]["ImageOutput"];
tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
img_blur: components["schemas"]["ImageOutput"];
merge_tiles_to_image: components["schemas"]["ImageOutput"];
latents_collection: components["schemas"]["LatentsCollectionOutput"];
img_mul: components["schemas"]["ImageOutput"];
image_mask_to_tensor: components["schemas"]["MaskOutput"];
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
cv_inpaint: components["schemas"]["ImageOutput"];
img_pad_crop: components["schemas"]["ImageOutput"];
lresize: components["schemas"]["LatentsOutput"];
conditioning: components["schemas"]["ConditioningOutput"];
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
tomask: components["schemas"]["ImageOutput"];
mul: components["schemas"]["IntegerOutput"];
seamless: components["schemas"]["SeamlessModeOutput"];
canny_image_processor: components["schemas"]["ImageOutput"];
metadata_item: components["schemas"]["MetadataItemOutput"];
add: components["schemas"]["IntegerOutput"];
crop_latents: components["schemas"]["LatentsOutput"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
string_split: components["schemas"]["String2Output"];
tile_image_processor: components["schemas"]["ImageOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
save_image: components["schemas"]["ImageOutput"];
controlnet: components["schemas"]["ControlOutput"];
float_math: components["schemas"]["FloatOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
i2l: components["schemas"]["LatentsOutput"];
infill_lama: components["schemas"]["ImageOutput"];
sub: components["schemas"]["IntegerOutput"];
div: components["schemas"]["IntegerOutput"];
face_mask_detection: components["schemas"]["FaceMaskOutput"];
esrgan: components["schemas"]["ImageOutput"];
mask_combine: components["schemas"]["ImageOutput"];
ip_adapter: components["schemas"]["IPAdapterOutput"];
blank_image: components["schemas"]["ImageOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
rand_int: components["schemas"]["IntegerOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
face_identifier: components["schemas"]["ImageOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
compel: components["schemas"]["ConditioningOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
freeu: components["schemas"]["UNetOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
img_conv: components["schemas"]["ImageOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
color_correct: components["schemas"]["ImageOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
noise: components["schemas"]["NoiseOutput"];
float_range: components["schemas"]["FloatCollectionOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
float_to_int: components["schemas"]["IntegerOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
latents: components["schemas"]["LatentsOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
mask_edge: components["schemas"]["ImageOutput"];
metadata: components["schemas"]["MetadataOutput"];
string_join: components["schemas"]["StringOutput"];
core_metadata: components["schemas"]["MetadataOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
lscale: components["schemas"]["LatentsOutput"];
string: components["schemas"]["StringOutput"];
integer: components["schemas"]["IntegerOutput"];
string_replace: components["schemas"]["StringOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
image: components["schemas"]["ImageOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
step_param_easing: components["schemas"]["FloatCollectionOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
img_chan: components["schemas"]["ImageOutput"];
float: components["schemas"]["FloatOutput"];
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
boolean: components["schemas"]["BooleanOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
denoise_latents: components["schemas"]["LatentsOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
color: components["schemas"]["ColorOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
rand_float: components["schemas"]["FloatOutput"];
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
range: components["schemas"]["IntegerCollectionOutput"];
img_watermark: components["schemas"]["ImageOutput"];
spandrel_image_to_image: components["schemas"]["ImageOutput"];
show_image: components["schemas"]["ImageOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
infill_tile: components["schemas"]["ImageOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
img_lerp: components["schemas"]["ImageOutput"];
l2i: components["schemas"]["ImageOutput"];
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
vae_loader: components["schemas"]["VAEOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
integer_math: components["schemas"]["IntegerOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
save_image: components["schemas"]["ImageOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
face_identifier: components["schemas"]["ImageOutput"];
img_blur: components["schemas"]["ImageOutput"];
img_paste: components["schemas"]["ImageOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"];
add: components["schemas"]["IntegerOutput"];
tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"];
img_crop: components["schemas"]["ImageOutput"];
img_resize: components["schemas"]["ImageOutput"];
conditioning: components["schemas"]["ConditioningOutput"];
esrgan: components["schemas"]["ImageOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
mul: components["schemas"]["IntegerOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
img_mul: components["schemas"]["ImageOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
infill_tile: components["schemas"]["ImageOutput"];
float: components["schemas"]["FloatOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
tomask: components["schemas"]["ImageOutput"];
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
integer_math: components["schemas"]["IntegerOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
sub: components["schemas"]["IntegerOutput"];
crop_latents: components["schemas"]["LatentsOutput"];
string_join_three: components["schemas"]["StringOutput"];
cv_inpaint: components["schemas"]["ImageOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
mask_combine: components["schemas"]["ImageOutput"];
img_pad_crop: components["schemas"]["ImageOutput"];
freeu: components["schemas"]["UNetOutput"];
lresize: components["schemas"]["LatentsOutput"];
metadata: components["schemas"]["MetadataOutput"];
color_map_image_processor: components["schemas"]["ImageOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
l2i: components["schemas"]["ImageOutput"];
show_image: components["schemas"]["ImageOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
round_float: components["schemas"]["FloatOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
spandrel_image_to_image: components["schemas"]["ImageOutput"];
step_param_easing: components["schemas"]["FloatCollectionOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
color_correct: components["schemas"]["ImageOutput"];
float_range: components["schemas"]["FloatCollectionOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
img_conv: components["schemas"]["ImageOutput"];
seamless: components["schemas"]["SeamlessModeOutput"];
ip_adapter: components["schemas"]["IPAdapterOutput"];
i2l: components["schemas"]["LatentsOutput"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
vae_loader: components["schemas"]["VAEOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
blank_image: components["schemas"]["ImageOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
mask_edge: components["schemas"]["ImageOutput"];
img_channel_multiply: components["schemas"]["ImageOutput"];
controlnet: components["schemas"]["ControlOutput"];
latents: components["schemas"]["LatentsOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
canny_image_processor: components["schemas"]["ImageOutput"];
core_metadata: components["schemas"]["MetadataOutput"];
div: components["schemas"]["IntegerOutput"];
string_split: components["schemas"]["String2Output"];
img_chan: components["schemas"]["ImageOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
rand_int: components["schemas"]["IntegerOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
face_off: components["schemas"]["FaceOffOutput"];
boolean: components["schemas"]["BooleanOutput"];
string: components["schemas"]["StringOutput"];
float_math: components["schemas"]["FloatOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
img_watermark: components["schemas"]["ImageOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
img_lerp: components["schemas"]["ImageOutput"];
image: components["schemas"]["ImageOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
tile_image_processor: components["schemas"]["ImageOutput"];
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
};
/**
* InvocationStartedEvent
@ -7790,7 +7792,7 @@ export type components = {
/**
* Fp32
* @description Whether or not to use full float32 precision
* @default false
* @default true
*/
fp32?: boolean;
/**
@ -9594,11 +9596,8 @@ export type components = {
* @description Information about why the job failed
*/
error_reason?: string | null;
/**
* Config In
* @description Configuration information (e.g. 'description') to apply to model.
*/
config_in?: Record<string, never>;
/** @description Configuration information (e.g. 'description') to apply to model. */
config_in?: components["schemas"]["ModelRecordChanges"];
/**
* Config Out
* @description After successful installation, this will hold the configuration object.
@ -9750,6 +9749,18 @@ export type components = {
* @description A set of changes to apply to a model.
*/
ModelRecordChanges: {
/**
* Source
* @description original source of the model
*/
source?: string | null;
/** @description type of model source */
source_type?: components["schemas"]["ModelSourceType"] | null;
/**
* Source Api Response
* @description metadata from remote source
*/
source_api_response?: string | null;
/**
* Name
* @description Name of the model.
@ -9767,6 +9778,18 @@ export type components = {
description?: string | null;
/** @description The base model. */
base?: components["schemas"]["BaseModelType"] | null;
/** @description Type of model */
type?: components["schemas"]["ModelType"] | null;
/**
* Key
* @description Database ID for this model
*/
key?: string | null;
/**
* Hash
* @description hash of model file
*/
hash?: string | null;
/**
* Trigger Phrases
* @description Set of trigger phrases for this model
@ -12610,6 +12633,7 @@ export type components = {
cover_image?: string | null;
/**
* Format
* @default diffusers
* @constant
* @enum {string}
*/
@ -14303,8 +14327,9 @@ export type operations = {
* - model/name:fp16:path/to/model.safetensors
* - model/name::path/to/model.safetensors
*
* `config` is an optional dict containing model configuration values that will override
* the ones that are probed automatically.
* `config` is a ModelRecordChanges object. Fields in this object will override
* the ones that are probed automatically. Pass an empty object to accept
* all the defaults.
*
* `access_token` is an optional access token for use with Urls that require
* authentication.
@ -14323,10 +14348,11 @@ export type operations = {
source: string;
/** @description Whether or not to install a local model in place */
inplace?: boolean | null;
/** @description access token for the remote resource */
access_token?: string | null;
};
};
requestBody?: {
requestBody: {
content: {
/**
* @example {
@ -14334,7 +14360,7 @@ export type operations = {
* "description": "string"
* }
*/
"application/json": Record<string, never> | null;
"application/json": components["schemas"]["ModelRecordChanges"];
};
};
responses: {

View File

@ -72,14 +72,16 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None
with pytest.raises((ValidationError, InvalidModelConfigException)):
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
key = mm2_installer.register_path(
embedding_file, ModelRecordChanges(name="banana_sushi", type=ModelType("lora"))
)
assert key is None
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.register_path(
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "key": "xyzzy"}
embedding_file, ModelRecordChanges(name="banana_sushi", source="fake/repo_id", key="xyzzy")
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
@ -131,7 +133,7 @@ def test_background_install(
path: Path = request.getfixturevalue(fixture_name)
description = "Test of metadata assignment"
source = LocalModelSource(path=path, inplace=False)
job = mm2_installer.import_model(source, config={"description": description})
job = mm2_installer.import_model(source, config=ModelRecordChanges(description=description))
assert job is not None
assert isinstance(job, ModelInstallJob)