Merge branch 'main' into stalker7779/modular_seamless

This commit is contained in:
Ryan Dick
2024-07-28 13:49:43 -04:00
committed by GitHub
134 changed files with 3982 additions and 2668 deletions

View File

@ -6,7 +6,7 @@ import pathlib
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type
from typing import List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
@ -430,13 +430,11 @@ async def delete_model_image(
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
example={"name": "string", "description": "string"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
@ -451,8 +449,9 @@ async def install_model(
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`config` is a ModelRecordChanges object. Fields in this object will override
the ones that are probed automatically. Pass an empty object to accept
all the defaults.
`access_token` is an optional access token for use with Urls that require
authentication.
@ -737,7 +736,7 @@ async def convert_model(
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path)
raw_model.save_pretrained(convert_path) # type: ignore
assert convert_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
@ -750,12 +749,12 @@ async def convert_model(
try:
new_key = installer.install_path(
convert_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
config=ModelRecordChanges(
name=original_name,
description=model_config.description,
hash=model_config.hash,
source=model_config.source,
),
)
except Exception as e:
logger.error(str(e))

View File

@ -23,7 +23,7 @@ from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
@ -36,16 +36,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
)
scale: float = InputField(
default=4.0,
gt=0.0,
le=16.0,
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
)
fit_to_multiple_of_8: bool = InputField(
default=False,
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@classmethod
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
@ -152,6 +142,47 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return pil_image
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@invocation(
"spandrel_image_to_image_autoscale",
title="Image-to-Image (Autoscale)",
tags=["upscale"],
category="upscale",
version="1.0.0",
)
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
scale: float = InputField(
default=4.0,
gt=0.0,
le=16.0,
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
)
fit_to_multiple_of_8: bool = InputField(
default=False,
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional, Union
from pydantic.networks import AnyHttpUrl
@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str:
"""
Probe and register the model at model_path.
@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:param config: ModelRecordChanges object that will override autoassigned model record values.
:returns id: The string ID of the registered model.
"""
@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str:
"""
Probe, register and install the model in the models directory.
@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:param config: ModelRecordChanges object that will override autoassigned model record values.
:returns id: The string ID of the registered model.
"""
@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
:param source: String source
:param config: Optional dict. Any fields in this dict
:param config: Optional ModelRecordChanges object. Any fields in this object
will override corresponding autoassigned probe fields in the
model's config record as described in `import_model()`.
:param access_token: Optional access token for remote sources.
@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> ModelInstallJob:
"""Install the indicated model.

View File

@ -2,13 +2,14 @@ import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from typing import Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@ -133,8 +134,9 @@ class ModelInstallJob(BaseModel):
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
config_in: ModelRecordChanges = Field(
default_factory=ModelRecordChanges,
description="Configuration information (e.g. 'description') to apply to model.",
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."

View File

@ -163,26 +163,27 @@ class ModelInstallService(ModelInstallServiceBase):
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
config = config or ModelRecordChanges()
if not config.source:
config.source = model_path.resolve().as_posix()
config.source_type = ModelSourceType.Path
return self._register(model_path, config)
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
config = config or ModelRecordChanges()
info: AnyModelConfig = ModelProbe.probe(
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
) # type: ignore
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
if preferred_name := config.get("name"):
if preferred_name := config.name:
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = (
@ -204,7 +205,7 @@ class ModelInstallService(ModelInstallServiceBase):
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
@ -216,7 +217,7 @@ class ModelInstallService(ModelInstallServiceBase):
source_obj.access_token = access_token
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
@ -318,16 +319,17 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = self._app_config.models_path / model_path
model_path = model_path.resolve()
config: dict[str, Any] = {}
config["name"] = model_name
config["description"] = stanza.get("description")
config = ModelRecordChanges(
name=model_name,
description=stanza.get("description"),
)
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
config.config_path = str(legacy_config_path)
try:
id = self.register_path(model_path=model_path, config=config)
self._logger.info(f"Migrated {model_name} with id {id}")
@ -500,11 +502,11 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
job.config_in.source = str(job.source)
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
job.config_in.source_api_response = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
@ -639,11 +641,11 @@ class ModelInstallService(ModelInstallServiceBase):
return new_path
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
) -> str:
config = config or {}
config = config or ModelRecordChanges()
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
model_path = model_path.resolve()
@ -674,11 +676,13 @@ class ModelInstallService(ModelInstallServiceBase):
precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
def _import_local_model(
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
) -> ModelInstallJob:
return ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or {},
config_in=config or ModelRecordChanges(),
local_path=Path(source.path),
inplace=source.inplace or False,
)
@ -686,7 +690,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_hf(
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
config: Optional[ModelRecordChanges] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
if source.access_token is None:
@ -702,7 +706,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(
self,
source: URLModelSource,
config: Optional[Dict[str, Any]],
config: Optional[ModelRecordChanges] = None,
) -> ModelInstallJob:
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
@ -717,7 +721,7 @@ class ModelInstallService(ModelInstallServiceBase):
source: HFModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
config: Optional[ModelRecordChanges],
) -> ModelInstallJob:
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
@ -730,7 +734,7 @@ class ModelInstallService(ModelInstallServiceBase):
install_job = ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or {},
config_in=config or ModelRecordChanges(),
source_metadata=metadata,
local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0,

View File

@ -18,6 +18,7 @@ from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
@ -66,10 +67,16 @@ class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model."""
# Changes applicable to all models
source: Optional[str] = Field(description="original source of the model", default=None)
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None