mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into stalker7779/modular_seamless
This commit is contained in:
@ -6,7 +6,7 @@ import pathlib
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
@ -430,13 +430,11 @@ async def delete_model_image(
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
example={"name": "string", "description": "string"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
|
||||
@ -451,8 +449,9 @@ async def install_model(
|
||||
- model/name:fp16:path/to/model.safetensors
|
||||
- model/name::path/to/model.safetensors
|
||||
|
||||
`config` is an optional dict containing model configuration values that will override
|
||||
the ones that are probed automatically.
|
||||
`config` is a ModelRecordChanges object. Fields in this object will override
|
||||
the ones that are probed automatically. Pass an empty object to accept
|
||||
all the defaults.
|
||||
|
||||
`access_token` is an optional access token for use with Urls that require
|
||||
authentication.
|
||||
@ -737,7 +736,7 @@ async def convert_model(
|
||||
# write the converted file to the convert path
|
||||
raw_model = converted_model.model
|
||||
assert hasattr(raw_model, "save_pretrained")
|
||||
raw_model.save_pretrained(convert_path)
|
||||
raw_model.save_pretrained(convert_path) # type: ignore
|
||||
assert convert_path.exists()
|
||||
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
@ -750,12 +749,12 @@ async def convert_model(
|
||||
try:
|
||||
new_key = installer.install_path(
|
||||
convert_path,
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"hash": model_config.hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
config=ModelRecordChanges(
|
||||
name=original_name,
|
||||
description=model_config.description,
|
||||
hash=model_config.hash,
|
||||
source=model_config.source,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
@ -23,7 +23,7 @@ from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
|
||||
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
|
||||
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
||||
|
||||
@ -36,16 +36,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(
|
||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||
)
|
||||
scale: float = InputField(
|
||||
default=4.0,
|
||||
gt=0.0,
|
||||
le=16.0,
|
||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
||||
)
|
||||
fit_to_multiple_of_8: bool = InputField(
|
||||
default=False,
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
|
||||
@ -152,6 +142,47 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return pil_image
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Upscale the image
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"spandrel_image_to_image_autoscale",
|
||||
title="Image-to-Image (Autoscale)",
|
||||
tags=["upscale"],
|
||||
category="upscale",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
|
||||
|
||||
scale: float = InputField(
|
||||
default=4.0,
|
||||
gt=0.0,
|
||||
le=16.0,
|
||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
||||
)
|
||||
fit_to_multiple_of_8: bool = InputField(
|
||||
default=False,
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
|
||||
This keeps the model in its current location.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
:param source: String source
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
:param config: Optional ModelRecordChanges object. Any fields in this object
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record as described in `import_model()`.
|
||||
:param access_token: Optional access token for remote sources.
|
||||
@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install the indicated model.
|
||||
|
||||
|
@ -2,13 +2,14 @@ import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
from typing import Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
@ -133,8 +134,9 @@ class ModelInstallJob(BaseModel):
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
config_in: ModelRecordChanges = Field(
|
||||
default_factory=ModelRecordChanges,
|
||||
description="Configuration information (e.g. 'description') to apply to model.",
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
|
@ -163,26 +163,27 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
config = config or ModelRecordChanges()
|
||||
if not config.source:
|
||||
config.source = model_path.resolve().as_posix()
|
||||
config.source_type = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
info: AnyModelConfig = ModelProbe.probe(
|
||||
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
|
||||
) # type: ignore
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
if preferred_name := config.name:
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
|
||||
dest_path = (
|
||||
@ -204,7 +205,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
@ -216,7 +217,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source_obj.access_token = access_token
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
|
||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||
if similar_jobs:
|
||||
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
||||
@ -318,16 +319,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
model_path = self._app_config.models_path / model_path
|
||||
model_path = model_path.resolve()
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
config["name"] = model_name
|
||||
config["description"] = stanza.get("description")
|
||||
config = ModelRecordChanges(
|
||||
name=model_name,
|
||||
description=stanza.get("description"),
|
||||
)
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
config.config_path = str(legacy_config_path)
|
||||
try:
|
||||
id = self.register_path(model_path=model_path, config=config)
|
||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||
@ -500,11 +502,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
job.config_in.source = str(job.source)
|
||||
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
job.config_in.source_api_response = job.source_metadata.api_response
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
@ -639,11 +641,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return new_path
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
@ -674,11 +676,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
precision = TorchDevice.choose_torch_dtype()
|
||||
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_local_model(
|
||||
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
|
||||
) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
@ -686,7 +690,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_hf(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
if source.access_token is None:
|
||||
@ -702,7 +706,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_url(
|
||||
self,
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
@ -717,7 +721,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: HFModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges],
|
||||
) -> ModelInstallJob:
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
@ -730,7 +734,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
install_job = ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
source_metadata=metadata,
|
||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
|
@ -18,6 +18,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
@ -66,10 +67,16 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
source: Optional[str] = Field(description="original source of the model", default=None)
|
||||
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
|
||||
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user