mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into stalker-modular_lora
This commit is contained in:
@ -6,7 +6,7 @@ import pathlib
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
@ -430,13 +430,11 @@ async def delete_model_image(
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
example={"name": "string", "description": "string"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
|
||||
@ -451,8 +449,9 @@ async def install_model(
|
||||
- model/name:fp16:path/to/model.safetensors
|
||||
- model/name::path/to/model.safetensors
|
||||
|
||||
`config` is an optional dict containing model configuration values that will override
|
||||
the ones that are probed automatically.
|
||||
`config` is a ModelRecordChanges object. Fields in this object will override
|
||||
the ones that are probed automatically. Pass an empty object to accept
|
||||
all the defaults.
|
||||
|
||||
`access_token` is an optional access token for use with Urls that require
|
||||
authentication.
|
||||
@ -737,7 +736,7 @@ async def convert_model(
|
||||
# write the converted file to the convert path
|
||||
raw_model = converted_model.model
|
||||
assert hasattr(raw_model, "save_pretrained")
|
||||
raw_model.save_pretrained(convert_path)
|
||||
raw_model.save_pretrained(convert_path) # type: ignore
|
||||
assert convert_path.exists()
|
||||
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
@ -750,12 +749,12 @@ async def convert_model(
|
||||
try:
|
||||
new_key = installer.install_path(
|
||||
convert_path,
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"hash": model_config.hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
config=ModelRecordChanges(
|
||||
name=original_name,
|
||||
description=model_config.description,
|
||||
hash=model_config.hash,
|
||||
source=model_config.source,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
blur_tensor[blur_tensor < 0] = 0.0
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
|
@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
@ -60,9 +60,13 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@ -499,6 +503,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_t2i_adapter_field(
|
||||
exit_stack: ExitStack,
|
||||
context: InvocationContext,
|
||||
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
ext_manager: ExtensionsManager,
|
||||
) -> None:
|
||||
if t2i_adapters is None:
|
||||
return
|
||||
|
||||
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
|
||||
if isinstance(t2i_adapters, T2IAdapterField):
|
||||
t2i_adapters = [t2i_adapters]
|
||||
|
||||
for t2i_adapter_field in t2i_adapters:
|
||||
ext_manager.add_extension(
|
||||
T2IAdapterExt(
|
||||
node_context=context,
|
||||
model_id=t2i_adapter_field.t2i_adapter_model,
|
||||
image=context.images.get_pil(t2i_adapter_field.image.image_name),
|
||||
weight=t2i_adapter_field.weight,
|
||||
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def prep_ip_adapter_image_prompts(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -708,7 +739,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
return mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@staticmethod
|
||||
def prepare_noise_and_latents(
|
||||
@ -766,10 +797,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
@ -802,21 +829,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
@ -844,6 +856,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
weight=lora_field.weight,
|
||||
)
|
||||
)
|
||||
### seamless
|
||||
if self.unet.seamless_axes:
|
||||
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
||||
|
||||
### inpaint
|
||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
||||
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
||||
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
||||
if unet_config.variant == ModelVariantType.Inpaint:
|
||||
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
||||
elif mask is not None:
|
||||
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||
|
||||
# Initialize context for modular denoise
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# context for loading additional models
|
||||
with ExitStack() as exit_stack:
|
||||
@ -852,6 +897,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||
# ext_manager.add_extension(ext)
|
||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
@ -883,6 +929,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
|
||||
# We invert the mask here for compatibility with the old backend implementation.
|
||||
if mask is not None:
|
||||
mask = 1 - mask
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
@ -927,7 +977,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ExitStack() as exit_stack,
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
|
@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
|
@ -23,7 +23,7 @@ from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
|
||||
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
|
||||
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
||||
|
||||
@ -36,16 +36,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(
|
||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||
)
|
||||
scale: float = InputField(
|
||||
default=4.0,
|
||||
gt=0.0,
|
||||
le=16.0,
|
||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
||||
)
|
||||
fit_to_multiple_of_8: bool = InputField(
|
||||
default=False,
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
|
||||
@ -152,6 +142,47 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return pil_image
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Upscale the image
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"spandrel_image_to_image_autoscale",
|
||||
title="Image-to-Image (Autoscale)",
|
||||
tags=["upscale"],
|
||||
category="upscale",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
|
||||
|
||||
scale: float = InputField(
|
||||
default=4.0,
|
||||
gt=0.0,
|
||||
le=16.0,
|
||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
||||
)
|
||||
fit_to_multiple_of_8: bool = InputField(
|
||||
default=False,
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
|
||||
This keeps the model in its current location.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
:param source: String source
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
:param config: Optional ModelRecordChanges object. Any fields in this object
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record as described in `import_model()`.
|
||||
:param access_token: Optional access token for remote sources.
|
||||
@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install the indicated model.
|
||||
|
||||
|
@ -2,13 +2,14 @@ import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
from typing import Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
@ -133,8 +134,9 @@ class ModelInstallJob(BaseModel):
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
config_in: ModelRecordChanges = Field(
|
||||
default_factory=ModelRecordChanges,
|
||||
description="Configuration information (e.g. 'description') to apply to model.",
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
|
@ -163,26 +163,27 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
config = config or ModelRecordChanges()
|
||||
if not config.source:
|
||||
config.source = model_path.resolve().as_posix()
|
||||
config.source_type = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
info: AnyModelConfig = ModelProbe.probe(
|
||||
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
|
||||
) # type: ignore
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
if preferred_name := config.name:
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
|
||||
dest_path = (
|
||||
@ -204,7 +205,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
@ -216,7 +217,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source_obj.access_token = access_token
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
|
||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||
if similar_jobs:
|
||||
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
||||
@ -318,16 +319,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
model_path = self._app_config.models_path / model_path
|
||||
model_path = model_path.resolve()
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
config["name"] = model_name
|
||||
config["description"] = stanza.get("description")
|
||||
config = ModelRecordChanges(
|
||||
name=model_name,
|
||||
description=stanza.get("description"),
|
||||
)
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
config.config_path = str(legacy_config_path)
|
||||
try:
|
||||
id = self.register_path(model_path=model_path, config=config)
|
||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||
@ -500,11 +502,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
job.config_in.source = str(job.source)
|
||||
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
job.config_in.source_api_response = job.source_metadata.api_response
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
@ -639,11 +641,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return new_path
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
@ -674,11 +676,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
precision = TorchDevice.choose_torch_dtype()
|
||||
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_local_model(
|
||||
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
|
||||
) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
@ -686,7 +690,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_hf(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
if source.access_token is None:
|
||||
@ -702,7 +706,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_url(
|
||||
self,
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
@ -717,7 +721,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: HFModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges],
|
||||
) -> ModelInstallJob:
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
@ -730,7 +734,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
install_job = ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
source_metadata=metadata,
|
||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
|
@ -18,6 +18,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
@ -66,10 +67,16 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
source: Optional[str] = Field(description="original source of the model", default=None)
|
||||
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
|
||||
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user