Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
ONNX Support (#3562)
Note: this branch is based on #3548, not on main. While working out what needed to be done to implement ONNX, I found I could put a draft together fairly quickly, so here it is. Supports LoRA and TI. As an example, a cat with the sadcatmeme LoRA: ![image](https://github.com/invoke-ai/InvokeAI/assets/7768370/dbd1a5df-0629-4741-94b3-8e09f4b4d5a3) ![image](https://github.com/invoke-ai/InvokeAI/assets/7768370/d918836c-fdc7-43c0-aa81-dde9182f2e0f)
Commit: 81654daed7
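For readers new to the ONNX path: unlike the torch pipeline, every ONNX submodel here runs through an onnxruntime InferenceSession and is fed plain numpy arrays. A minimal sketch of that calling convention (not InvokeAI code; the input names match the diffusers-style UNet export this PR targets, and the path and shapes are illustrative):

# Minimal sketch of driving an exported ONNX UNet with onnxruntime.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("unet/model.onnx", providers=ort.get_available_providers())
sample = np.zeros((2, 4, 64, 64), dtype=np.float16)    # latents batched for CFG (uncond + cond)
timestep = np.array([999], dtype=np.float16)
hidden = np.zeros((2, 77, 768), dtype=np.float16)      # CLIP text embeddings
noise_pred = sess.run(None, {"sample": sample, "timestep": timestep,
                             "encoder_hidden_states": hidden})[0]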
@@ -455,7 +455,7 @@ def get_torch_source() -> (Union[str, None], str):
     device = graphical_accelerator()

     url = None
-    optional_modules = None
+    optional_modules = "[onnx]"
     if OS == "Linux":
         if device == "rocm":
             url = "https://download.pytorch.org/whl/rocm5.4.2"
@@ -464,7 +464,10 @@ def get_torch_source() -> (Union[str, None], str):

     if device == "cuda":
         url = "https://download.pytorch.org/whl/cu117"
-        optional_modules = "[xformers]"
+        optional_modules = "[xformers,onnx-cuda]"
+    if device == "cuda_and_dml":
+        url = "https://download.pytorch.org/whl/cu117"
+        optional_modules = "[xformers,onnx-directml]"

     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

@@ -167,6 +167,10 @@ def graphical_accelerator():
         "an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
         "cuda",
     )
+    nvidia_with_dml = (
+        "an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
+        "cuda_and_dml",
+    )
     amd = (
         "an [gold1 b]AMD[/] GPU (using ROCm™)",
         "rocm",
@@ -181,7 +185,7 @@ def graphical_accelerator():
     )

     if OS == "Windows":
-        options = [nvidia, cpu]
+        options = [nvidia, nvidia_with_dml, cpu]
     if OS == "Linux":
         options = [nvidia, amd, cpu]
     elif OS == "Darwin":
@ -1,6 +1,14 @@
|
||||
from typing import Literal, Optional, Union, List, Annotated
|
||||
from pydantic import BaseModel, Field
|
||||
import re
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from .model import ClipField
|
||||
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
|
@ -24,6 +24,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.model_management import ModelPatcher
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
|
@ -53,6 +53,7 @@ class MainModelField(BaseModel):
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
|
invokeai/app/invocations/onnx.py (new file, 578 lines)
@@ -0,0 +1,578 @@
|
||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import re
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.util import choose_torch_device
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
from tqdm import tqdm
|
||||
from .model import ClipField
|
||||
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
|
||||
from .compel import CompelOutput
|
||||
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
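ORT_TO_NP_TYPE exists because onnxruntime reports input dtypes as strings such as "tensor(float16)"; the node uses the table to cast latents and timesteps to whatever the exported graph expects. A tiny illustration (assumed values, not part of the PR):

# Illustration only: casting inputs to the dtype an ONNX graph reports.
import numpy as np
ORT_TO_NP_TYPE = {"tensor(float16)": np.float16, "tensor(float)": np.float32}
reported = "tensor(float16)"                    # e.g. session.get_inputs()[0].type
latents = np.random.randn(1, 4, 64, 64).astype(ORT_TO_NP_TYPE[reported])
assert latents.dtype == np.float16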
|
||||
|
||||
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.clip.loras
|
||||
]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
# stack.enter_context(
|
||||
# context.services.model_manager.get_model(
|
||||
# model_name=name,
|
||||
# base_model=self.clip.text_encoder.base_model,
|
||||
# model_type=ModelType.TextualInversion,
|
||||
# )
|
||||
# )
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
)
|
||||
except Exception:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
if loras or ti_list:
|
||||
text_encoder.release_session()
|
||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
|
||||
orig_tokenizer, text_encoder, ti_list
|
||||
) as (tokenizer, ti_manager):
|
||||
text_encoder.create_session()
|
||||
|
||||
# copy from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
||||
text_inputs = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
"""
|
||||
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
||||
|
||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
"""
|
||||
|
||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
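Stripped of the model-manager plumbing, the prompt node boils down to: tokenize to the fixed CLIP window, cast the ids to int32 (the dtype most ONNX CLIP exports take), and run the text encoder once. A hedged sketch, where text_encoder stands in for the loaded IAIOnnxRuntimeModel:

# Sketch, assuming `text_encoder` is a callable ONNX session wrapper as above.
import numpy as np
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tokenizer("a cat", padding="max_length", max_length=tokenizer.model_max_length,
                truncation=True, return_tensors="np").input_ids
prompt_embeds = text_encoder(input_ids=ids.astype(np.int32))[0]   # shape (1, 77, 768)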
|
||||
|
||||
|
||||
# Text to image
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
if isinstance(c, torch.Tensor):
|
||||
c = c.cpu().numpy()
|
||||
if isinstance(uc, torch.Tensor):
|
||||
uc = uc.cpu().numpy()
|
||||
device = torch.device(choose_torch_device())
|
||||
prompt_embeds = np.concatenate([uc, c])
|
||||
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
if isinstance(latents, torch.Tensor):
|
||||
latents = latents.cpu().numpy()
|
||||
|
||||
# TODO: better execution device handling
|
||||
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
do_classifier_free_guidance = True
|
||||
# latents_dtype = prompt_embeds.dtype
|
||||
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
# if latents.shape != latents_shape:
|
||||
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
def torch2numpy(latent: torch.Tensor):
|
||||
return latent.cpu().numpy()
|
||||
|
||||
def numpy2torch(latent, device):
|
||||
return torch.from_numpy(latent).to(device)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
scheduler.set_timesteps(self.steps)
|
||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
|
||||
with unet_info as unet, ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.unet.loras
|
||||
]
|
||||
|
||||
if loras:
|
||||
unet.release_session()
|
||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||
# TODO:
|
||||
_, _, h, w = latents.shape
|
||||
unet.create_session(h, w)
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
for i in tqdm(range(len(scheduler.timesteps))):
|
||||
t = scheduler.timesteps[i]
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
||||
latent_model_input = latent_model_input.cpu().numpy()
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = scheduler.step(
|
||||
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
||||
)
|
||||
latents = torch2numpy(scheduler_output.prev_sample)
|
||||
|
||||
state = PipelineIntermediateState(
|
||||
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
|
||||
)
|
||||
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
|
||||
|
||||
# call the callback, if provided
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
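The guidance step inside the loop is ordinary numpy arithmetic: the batched prediction is split into its unconditioned and conditioned halves, and the residual is pushed cfg_scale times further toward the conditioned half. In isolation:

# Classifier-free guidance on a dummy prediction (illustration, not InvokeAI code).
import numpy as np
cfg_scale = 7.5
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)   # [uncond, cond] halves
uncond, cond = np.split(noise_pred, 2)
guided = uncond + cfg_scale * (cond - uncond)                    # shape (1, 4, 64, 64)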
|
||||
|
||||
|
||||
# Latent to image
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
metadata: Optional[CoreMetadata] = Field(
|
||||
default=None, description="Optional core metadata to be written to the image"
|
||||
)
|
||||
# tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with vae_info as vae:
|
||||
vae.create_session()
|
||||
|
||||
# copied from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# there seems to be a strange result when using the half-precision vae decoder with batch size > 1
|
||||
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
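The decode post-processing mirrors the diffusers ONNX pipeline: un-scale by the SD latent factor, decode, then map the [-1, 1] output to [0, 1] NHWC before converting to PIL. The array handling on its own (illustrative shapes; the decoded array stands in for the VAE output):

# Post-processing a decoded ONNX VAE output (illustration only).
import numpy as np
latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
latents = 1 / 0.18215 * latents                        # undo SD latent scaling before decode
decoded = np.random.uniform(-1, 1, (1, 3, 512, 512)).astype(np.float32)  # stand-in for vae(latent_sample=...)
image = np.clip(decoded / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))          # NCHW -> NHWC, range [0, 1]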
|
||||
|
||||
|
||||
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
|
||||
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
model_name = "stable-diffusion-v1-5"
|
||||
base_model = BaseModelType.StableDiffusion1
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.ONNX,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {model_name}!")
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||
|
||||
model: OnnxModelField = Field(description="The model to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Onnx Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.ONNX
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
@ -12,6 +12,7 @@ from typing import Optional, List, Dict, Callable, Union, Set
|
||||
import requests
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
import onnx
|
||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
@ -302,8 +303,10 @@ class ModelInstall(object):
|
||||
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
staging = Path(staging)
|
||||
if "model_index.json" in files:
|
||||
if "model_index.json" in files and "unet/model.onnx" not in files:
|
||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||
elif "unet/model.onnx" in files:
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
else:
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
if f"pytorch_lora_weights.{suffix}" in files:
|
||||
@ -368,7 +371,7 @@ class ModelInstall(object):
|
||||
model_format=info.format,
|
||||
)
|
||||
legacy_conf = None
|
||||
if info.model_type == ModelType.Main:
|
||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||
attributes.update(
|
||||
dict(
|
||||
variant=info.variant_type,
|
||||
@ -433,8 +436,13 @@ class ModelInstall(object):
|
||||
location = staging / name
|
||||
paths = list()
|
||||
for filename in files:
|
||||
filePath = Path(filename)
|
||||
p = hf_download_with_resume(
|
||||
repo_id, model_dir=location, model_name=filename, access_token=self.access_token
|
||||
repo_id,
|
||||
model_dir=location / filePath.parent,
|
||||
model_name=filePath.name,
|
||||
access_token=self.access_token,
|
||||
subfolder=filePath.parent,
|
||||
)
|
||||
if p:
|
||||
paths.append(p)
|
||||
@ -482,11 +490,12 @@ def hf_download_with_resume(
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
subfolder: str = None,
|
||||
) -> Path:
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
||||
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
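Because ONNX repos keep their weights in per-component subfolders (unet/, text_encoder/, vae_decoder/, ...), the installer now forwards each file's parent directory as subfolder, and hf_hub_url builds the nested URL. For example (URL formatting only, no network access; the repo id is illustrative):

# hf_hub_url only formats a URL; it does not hit the network.
from huggingface_hub import hf_hub_url
url = hf_hub_url("runwayml/stable-diffusion-v1-5", "model.onnx", subfolder="unet")
# -> https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/model.onnx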
|
||||
|
@ -3,6 +3,7 @@ Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||
from .model_cache import ModelCache
|
||||
from .lora import ModelPatcher, ONNXModelPatcher
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
|
@ -6,11 +6,22 @@ from typing import Optional, Dict, Tuple, Any, Union, List
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import OrtValue
|
||||
import numpy as np
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# TODO: rename and split this file
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
@ -698,3 +709,186 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
new_token_ids.extend(self.pad_tokens[token_id])
|
||||
|
||||
return new_token_ids
|
||||
|
||||
|
||||
class ONNXModelPatcher:
|
||||
from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: OnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
):
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
# based on
|
||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: IAIOnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
from .models.base import IAIOnnxRuntimeModel
|
||||
|
||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_weights = dict()
|
||||
|
||||
try:
|
||||
blended_loras = dict()
|
||||
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
layer.to(dtype=torch.float32)
|
||||
layer_key = layer_key.replace(prefix, "")
|
||||
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
|
||||
if layer_key in blended_loras:
|
||||
blended_loras[layer_key] += layer_weight
|
||||
else:
|
||||
blended_loras[layer_key] = layer_weight
|
||||
|
||||
node_names = dict()
|
||||
for node in model.nodes.values():
|
||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||
|
||||
for layer_key, lora_weight in blended_loras.items():
|
||||
conv_key = layer_key + "_Conv"
|
||||
gemm_key = layer_key + "_Gemm"
|
||||
matmul_key = layer_key + "_MatMul"
|
||||
|
||||
if conv_key in node_names or gemm_key in node_names:
|
||||
if conv_key in node_names:
|
||||
conv_node = model.nodes[node_names[conv_key]]
|
||||
else:
|
||||
conv_node = model.nodes[node_names[gemm_key]]
|
||||
|
||||
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
||||
orig_weight = model.tensors[weight_name]
|
||||
|
||||
if orig_weight.shape[-2:] == (1, 1):
|
||||
if lora_weight.shape[-2:] == (1, 1):
|
||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
|
||||
else:
|
||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
|
||||
|
||||
new_weight = np.expand_dims(new_weight, (2, 3))
|
||||
else:
|
||||
if orig_weight.shape != lora_weight.shape:
|
||||
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
|
||||
else:
|
||||
new_weight = orig_weight + lora_weight
|
||||
|
||||
orig_weights[weight_name] = orig_weight
|
||||
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
|
||||
|
||||
elif matmul_key in node_names:
|
||||
weight_node = model.nodes[node_names[matmul_key]]
|
||||
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
||||
|
||||
orig_weight = model.tensors[matmul_name]
|
||||
new_weight = orig_weight + lora_weight.transpose()
|
||||
|
||||
orig_weights[matmul_name] = orig_weight
|
||||
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
|
||||
|
||||
else:
|
||||
# warn? err?
|
||||
pass
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# restore original weights
|
||||
for name, orig_weight in orig_weights.items():
|
||||
model.tensors[name] = orig_weight
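Since an ONNX graph has no torch modules to hook, the patcher folds each blended LoRA delta straight into the matching initializer and restores the original tensor afterwards. The core arithmetic for the MatMul case, in isolation (made-up shapes and scale):

# Folding a LoRA delta into a linear (MatMul) weight, as the patcher above does.
import numpy as np
orig_weight = np.random.randn(320, 768).astype(np.float32)       # graph initializer
lora_delta = np.random.randn(768, 320).astype(np.float32) * 0.8  # get_weight() * lora_weight
patched = (orig_weight + lora_delta.transpose()).astype(orig_weight.dtype)
# run inference with `patched` swapped into the graph, then restore orig_weight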
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
cls,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: IAIOnnxRuntimeModel,
|
||||
ti_list: List[Any],
|
||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||
from .models.base import IAIOnnxRuntimeModel
|
||||
|
||||
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_embeddings = None
|
||||
|
||||
try:
|
||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
|
||||
def _get_trigger(ti, index):
|
||||
trigger = ti.name
|
||||
if index > 0:
|
||||
trigger += f"-!pad-{index}"
|
||||
return f"<{trigger}>"
|
||||
|
||||
# modify tokenizer
|
||||
new_tokens_added = 0
|
||||
for ti in ti_list:
|
||||
for i in range(ti.embedding.shape[0]):
|
||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
||||
|
||||
# modify text_encoder
|
||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||
|
||||
embeddings = np.concatenate(
|
||||
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
|
||||
axis=0,
|
||||
)
|
||||
|
||||
for ti in ti_list:
|
||||
ti_tokens = []
|
||||
for i in range(ti.embedding.shape[0]):
|
||||
embedding = ti.embedding[i].detach().numpy()
|
||||
trigger = _get_trigger(ti, i)
|
||||
|
||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||
if token_id == ti_tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
||||
|
||||
if embeddings[token_id].shape != embedding.shape:
|
||||
raise ValueError(
|
||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
|
||||
)
|
||||
|
||||
embeddings[token_id] = embedding
|
||||
ti_tokens.append(token_id)
|
||||
|
||||
if len(ti_tokens) > 1:
|
||||
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
||||
|
||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
|
||||
orig_embeddings.dtype
|
||||
)
|
||||
|
||||
yield ti_tokenizer, ti_manager
|
||||
|
||||
finally:
|
||||
# restore
|
||||
if orig_embeddings is not None:
|
||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
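Multi-vector embeddings get one extra pad trigger per additional vector, and the token-embedding initializer is grown by the same number of rows. A compact illustration of the trigger naming and table growth (pure numpy, no tokenizer; the sizes are the usual CLIP ones but assumed here):

# Illustration of how a 3-vector embedding named "sadcatmeme" expands.
import numpy as np
def get_trigger(name, index):
    return f"<{name}>" if index == 0 else f"<{name}-!pad-{index}>"

triggers = [get_trigger("sadcatmeme", i) for i in range(3)]
# ['<sadcatmeme>', '<sadcatmeme-!pad-1>', '<sadcatmeme-!pad-2>']
orig = np.zeros((49408, 768), dtype=np.float32)                 # CLIP token embedding table
grown = np.concatenate((orig, np.zeros((3, 768))), axis=0).astype(orig.dtype)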
|
||||
|
@ -360,7 +360,8 @@ class ModelCache(object):
|
||||
# 2 refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
if not cache_entry.locked and refs <= 2:
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
|
@ -27,7 +27,7 @@ class ModelProbeInfo(object):
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: Literal["diffusers", "checkpoint", "lycoris"]
|
||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||
image_size: int
|
||||
|
||||
|
||||
@ -41,6 +41,7 @@ class ModelProbe(object):
|
||||
PROBES = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
"onnx": {},
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
@ -53,7 +54,9 @@ class ModelProbe(object):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase):
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
||||
):
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
@ -95,6 +98,7 @@ class ModelProbe(object):
|
||||
if format_type == "diffusers"
|
||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||
)
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
return None
|
||||
@ -168,6 +172,8 @@ class ModelProbe(object):
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "learned_embeds.bin").exists():
|
||||
return ModelType.TextualInversion
|
||||
|
||||
@ -460,6 +466,17 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||
|
||||
|
||||
class ONNXFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return "onnx"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.folder_path / "config.json"
|
||||
@ -497,3 +514,4 @@ ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
|
@ -23,8 +23,11 @@ from .lora import LoRAModel
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .textual_inversion import TextualInversionModel
|
||||
|
||||
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
|
||||
|
||||
MODEL_CLASSES = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelType.ONNX: ONNXStableDiffusion1Model,
|
||||
ModelType.Main: StableDiffusion1Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
@ -32,6 +35,7 @@ MODEL_CLASSES = {
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
ModelType.Main: StableDiffusion2Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
@ -45,6 +49,7 @@ MODEL_CLASSES = {
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelType.Main: StableDiffusionXLModel,
|
||||
@ -53,6 +58,7 @@ MODEL_CLASSES = {
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||
},
|
||||
# BaseModelType.Kandinsky2_1: {
|
||||
# ModelType.Main: Kandinsky2_1Model,
|
||||
|
@ -8,13 +8,23 @@ from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from picklescan.scanner import scan_file_path
|
||||
import torch
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
from diffusers import DiffusionPipeline, ConfigMixin
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
|
||||
|
||||
from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import (
|
||||
InferenceSession,
|
||||
SessionOptions,
|
||||
get_available_providers,
|
||||
)
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
pass
|
||||
@ -37,6 +47,7 @@ class BaseModelType(str, Enum):
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
@ -51,6 +62,8 @@ class SubModelType(str, Enum):
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
# MoVQ = "movq"
|
||||
@ -362,6 +375,8 @@ def calc_model_size_by_data(model) -> int:
|
||||
return _calc_pipeline_by_data(model)
|
||||
elif isinstance(model, torch.nn.Module):
|
||||
return _calc_model_by_data(model)
|
||||
elif isinstance(model, IAIOnnxRuntimeModel):
|
||||
return _calc_onnx_model_by_data(model)
|
||||
else:
|
||||
return 0
|
||||
|
||||
@ -382,6 +397,12 @@ def _calc_model_by_data(model) -> int:
|
||||
return mem
|
||||
|
||||
|
||||
def _calc_onnx_model_by_data(model) -> int:
|
||||
tensor_size = model.tensors.size() * 2 # The session doubles this
|
||||
mem = tensor_size # in bytes
|
||||
return mem
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
device = torch.device("meta")
|
||||
@ -449,3 +470,208 @@ class SilenceWarnings(object):
|
||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||
warnings.simplefilter("default")
|
||||
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
|
||||
|
||||
class IAIOnnxRuntimeModel:
|
||||
class _tensor_access:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.indexes = dict()
|
||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
value = self.model.proto.graph.initializer[self.indexes[key]]
|
||||
return numpy_helper.to_array(value)
|
||||
|
||||
def __setitem__(self, key: str, value: np.ndarray):
|
||||
new_node = numpy_helper.from_array(value)
|
||||
# set_external_data(new_node, location="in-memory-location")
|
||||
new_node.name = key
|
||||
# new_node.ClearField("raw_data")
|
||||
del self.model.proto.graph.initializer[self.indexes[key]]
|
||||
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
||||
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self.indexes
|
||||
|
||||
def items(self):
|
||||
raise NotImplementedError("tensor.items")
|
||||
# return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self):
|
||||
return self.indexes.keys()
|
||||
|
||||
def values(self):
|
||||
raise NotImplementedError("tensor.values")
|
||||
# return [obj for obj in self.raw_proto]
|
||||
|
||||
def size(self):
|
||||
bytesSum = 0
|
||||
for node in self.model.proto.graph.initializer:
|
||||
bytesSum += sys.getsizeof(node.raw_data)
|
||||
return bytesSum
|
||||
|
||||
class _access_helper:
|
||||
def __init__(self, raw_proto):
|
||||
self.indexes = dict()
|
||||
self.raw_proto = raw_proto
|
||||
for idx, obj in enumerate(raw_proto):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return self.raw_proto[self.indexes[key]]
|
||||
|
||||
def __setitem__(self, key: str, value):
|
||||
index = self.indexes[key]
|
||||
del self.raw_proto[index]
|
||||
self.raw_proto.insert(index, value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self.indexes
|
||||
|
||||
def items(self):
|
||||
return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self):
|
||||
return self.indexes.keys()
|
||||
|
||||
def values(self):
|
||||
return [obj for obj in self.raw_proto]
|
||||
|
||||
def __init__(self, model_path: str, provider: Optional[str]):
|
||||
self.path = model_path
|
||||
self.session = None
|
||||
self.provider = provider
|
||||
"""
|
||||
self.data_path = self.path + "_data"
|
||||
if not os.path.exists(self.data_path):
|
||||
print(f"Moving model tensors to separate file: {self.data_path}")
|
||||
tmp_proto = onnx.load(model_path, load_external_data=True)
|
||||
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
|
||||
del tmp_proto
|
||||
gc.collect()
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=False)
|
||||
"""
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=True)
|
||||
# self.data = dict()
|
||||
# for tensor in self.proto.graph.initializer:
|
||||
# name = tensor.name
|
||||
|
||||
# if tensor.HasField("raw_data"):
|
||||
# npt = numpy_helper.to_array(tensor)
|
||||
# orv = OrtValue.ortvalue_from_numpy(npt)
|
||||
# # self.data[name] = orv
|
||||
# # set_external_data(tensor, location="in-memory-location")
|
||||
# tensor.name = name
|
||||
# # tensor.ClearField("raw_data")
|
||||
|
||||
self.nodes = self._access_helper(self.proto.graph.node)
|
||||
# self.initializers = self._access_helper(self.proto.graph.initializer)
|
||||
# print(self.proto.graph.input)
|
||||
# print(self.proto.graph.initializer)
|
||||
|
||||
self.tensors = self._tensor_access(self)
|
||||
|
||||
# TODO: integrate with model manager/cache
|
||||
def create_session(self, height=None, width=None):
|
||||
if self.session is None or self.session_width != width or self.session_height != height:
|
||||
# onnx.save(self.proto, "tmp.onnx")
|
||||
# onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
||||
# TODO: something to be able to get weight when they already moved outside of model proto
|
||||
# (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
|
||||
sess = SessionOptions()
|
||||
# self._external_data.update(**external_data)
|
||||
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
|
||||
# sess.enable_profiling = True
|
||||
|
||||
# sess.intra_op_num_threads = 1
|
||||
# sess.inter_op_num_threads = 1
|
||||
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
||||
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
# sess.enable_cpu_mem_arena = True
|
||||
# sess.enable_mem_pattern = True
|
||||
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
||||
self.session_height = height
|
||||
self.session_width = width
|
||||
if height and width:
|
||||
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
|
||||
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
||||
providers = []
|
||||
if self.provider:
|
||||
providers.append(self.provider)
|
||||
else:
|
||||
providers = get_available_providers()
|
||||
if "TensorrtExecutionProvider" in providers:
|
||||
providers.remove("TensorrtExecutionProvider")
|
||||
try:
|
||||
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
|
||||
except Exception as e:
|
||||
raise e
|
||||
# self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
|
||||
# self.io_binding = self.session.io_binding()
|
||||
|
||||
def release_session(self):
|
||||
self.session = None
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
return
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
if self.session is None:
|
||||
raise Exception("You should call create_session before running model")
|
||||
|
||||
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
||||
output_names = self.session.get_outputs()
|
||||
# for k in inputs:
|
||||
# self.io_binding.bind_cpu_input(k, inputs[k])
|
||||
# for name in output_names:
|
||||
# self.io_binding.bind_output(name.name)
|
||||
# self.session.run_with_iobinding(self.io_binding, None)
|
||||
# return self.io_binding.copy_outputs_to_cpu()
|
||||
return self.session.run(None, inputs)
|
||||
|
||||
# compatibility with diffusers load code
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_id: Union[str, Path],
|
||||
subfolder: Union[str, Path] = None,
|
||||
file_name: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
sess_options: Optional["SessionOptions"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
file_name = file_name or ONNX_WEIGHTS_NAME
|
||||
|
||||
if os.path.isdir(model_id):
|
||||
model_path = model_id
|
||||
if subfolder is not None:
|
||||
model_path = os.path.join(model_path, subfolder)
|
||||
model_path = os.path.join(model_path, file_name)
|
||||
|
||||
else:
|
||||
model_path = model_id
|
||||
|
||||
# load model from local directory
|
||||
if not os.path.isfile(model_path):
|
||||
raise Exception(f"Model not found: {model_path}")
|
||||
|
||||
# TODO: session options
|
||||
return cls(model_path, provider=provider)
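Putting the wrapper's pieces together, a typical round trip looks like the sketch below (hypothetical path; sample, timestep and hidden are numpy arrays shaped as in the earlier UNet sketch). The point is that the session is built lazily from the in-memory proto, so tensors can be patched before create_session() and released again afterwards:

# Hedged usage sketch for IAIOnnxRuntimeModel; the model path is illustrative.
unet = IAIOnnxRuntimeModel.from_pretrained(
    "/models/sd15-onnx", subfolder="unet", provider="CPUExecutionProvider"
)
# optionally patch unet.tensors[...] here, e.g. to fold in LoRA weights
unet.create_session(64, 64)                      # latent-space height/width
noise_pred = unet(sample=sample, timestep=timestep, encoder_hidden_states=hidden)[0]
unet.release_session()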
|
||||
|
@ -0,0 +1,157 @@
|
||||
import os
|
||||
import json
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelVariantType,
|
||||
DiffusersModel,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
read_checkpoint_meta,
|
||||
classproperty,
|
||||
OnnxRuntimeModel,
|
||||
IAIOnnxRuntimeModel,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
|
||||
class StableDiffusionOnnxModelFormat(str, Enum):
|
||||
Olive = "olive"
|
||||
Onnx = "onnx"
|
||||
|
||||
|
||||
class ONNXStableDiffusion1Model(DiffusersModel):
|
||||
class Config(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
assert model_type == ModelType.ONNX
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.ONNX,
|
||||
)
|
||||
|
||||
for child_name, child_type in self.child_types.items():
|
||||
if child_type is OnnxRuntimeModel:
|
||||
self.child_types[child_name] = IAIOnnxRuntimeModel
|
||||
|
||||
# TODO: check that no optimum models provided
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
in_channels = 4 # TODO:
|
||||
|
||||
if in_channels == 9:
|
||||
variant = ModelVariantType.Inpaint
|
||||
elif in_channels == 4:
|
||||
variant = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion 1.* model format")
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
# TODO: Detect onnx vs olive
|
||||
return StableDiffusionOnnxModelFormat.Onnx
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
return model_path
|
||||
|
||||
|
||||
class ONNXStableDiffusion2Model(DiffusersModel):
|
||||
# TODO: check that configs are overwritten properly
|
||||
class Config(ModelConfigBase):
|
||||
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion2
|
||||
assert model_type == ModelType.ONNX
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
base_model=BaseModelType.StableDiffusion2,
|
||||
model_type=ModelType.ONNX,
|
||||
)
|
||||
|
||||
for child_name, child_type in self.child_types.items():
|
||||
if child_type is OnnxRuntimeModel:
|
||||
self.child_types[child_name] = IAIOnnxRuntimeModel
|
||||
# TODO: check that no optimum models provided
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
in_channels = 4 # TODO:
|
||||
|
||||
if in_channels == 9:
|
||||
variant = ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
variant = ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
variant = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion 2.* model format")
|
||||
|
||||
if variant == ModelVariantType.Normal:
|
||||
prediction_type = SchedulerPredictionType.VPrediction
|
||||
upcast_attention = True
|
||||
|
||||
else:
|
||||
prediction_type = SchedulerPredictionType.Epsilon
|
||||
upcast_attention = False
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
model_format=model_format,
|
||||
variant=variant,
|
||||
prediction_type=prediction_type,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
# TODO: Detect onnx vs olive
|
||||
return StableDiffusionOnnxModelFormat.Onnx
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
return model_path
|
invokeai/frontend/web/dist/assets/App-44cdaaf3.js (vendored, new file, 169 lines)
File diff suppressed because one or more lines are too long
invokeai/frontend/web/dist/assets/App-ea7b7298.js (vendored, 169 lines)
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
||||
import{A as m,f_ as Je,z as y,a4 as Ka,f$ as Xa,af as va,aj as d,g0 as b,g1 as t,g2 as Ya,g3 as h,g4 as ua,g5 as Ja,g6 as Qa,aI as Za,g7 as et,ad as rt,g8 as at}from"./index-9bb68e3a.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-ae002ae6.js";var za=String.raw,Ca=za`
|
||||
import{A as m,f$ as Je,z as y,a4 as Ka,g0 as Xa,af as va,aj as d,g1 as b,g2 as t,g3 as Ya,g4 as h,g5 as ua,g6 as Ja,g7 as Qa,aI as Za,g8 as et,ad as rt,g9 as at}from"./index-18f2f740.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-b20a2267.js";var za=String.raw,Ca=za`
|
||||
:root,
|
||||
:host {
|
||||
--chakra-vh: 100vh;
|
invokeai/frontend/web/dist/assets/index-18f2f740.js (vendored, new file, 125 lines)
File diff suppressed because one or more lines are too long
125
invokeai/frontend/web/dist/assets/index-9bb68e3a.js
vendored
125
invokeai/frontend/web/dist/assets/index-9bb68e3a.js
vendored
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-9bb68e3a.js"></script>
<script type="module" crossorigin src="./assets/index-18f2f740.js"></script>
</head>

<body dir="ltr">
2 invokeai/frontend/web/dist/locales/en.json vendored
@ -342,6 +342,8 @@
"diffusersModels": "Diffusers",
"loraModels": "LoRAs",
"safetensorModels": "SafeTensors",
"onnxModels": "Onnx",
"oliveModels": "Olives",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",

@ -342,6 +342,8 @@
"diffusersModels": "Diffusers",
"loraModels": "LoRAs",
"safetensorModels": "SafeTensors",
"onnxModels": "Onnx",
"oliveModels": "Olives",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
@ -36,7 +36,8 @@ export const addModelsLoadedListener = () => {
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);

if (isCurrentModelAvailable) {
@ -83,7 +84,8 @@ export const addModelsLoadedListener = () => {
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
m?.base_model === currentModel?.base_model &&
m?.model_type === currentModel?.model_type
);

if (isCurrentModelAvailable) {
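The check above now also compares model_type, so a diffusers ("main") model and an ONNX model that happen to share a name and base are no longer treated as the same entry. A minimal, self-contained sketch of that predicate in TypeScript (the ModelIdentifier shape here is assumed for illustration, not taken from the PR):

type ModelIdentifier = { model_name: string; base_model: string; model_type: string };

const isSameModel = (a?: ModelIdentifier, b?: ModelIdentifier): boolean =>
  a?.model_name === b?.model_name &&
  a?.base_model === b?.base_model &&
  a?.model_type === b?.model_type; // new: 'main' vs 'onnx' now matters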
@ -47,9 +47,9 @@ export const addTabChangedListener = () => {
}

// only store the model name and base model in redux
const { base_model, model_name } = firstValidCanvasModel;
const { base_model, model_name, model_type } = firstValidCanvasModel;

dispatch(modelChanged({ base_model, model_name }));
dispatch(modelChanged({ base_model, model_name, model_type }));
}
},
});
@ -14,8 +14,11 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';

@ -28,6 +31,7 @@ const ModelInputFieldComponent = (
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

const { data: onnxModels } = useGetOnnxModelsQuery(NON_REFINER_BASE_MODELS);
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
@ -51,17 +55,39 @@ const ModelInputFieldComponent = (
});
});

if (onnxModels) {
forEach(onnxModels.entities, (model, id) => {
if (!model) {
return;
}

data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
}
return data;
}, [mainModels]);
}, [mainModels, onnxModels]);

// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
mainModels?.entities[
(mainModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ?? null,
[field.value?.base_model, field.value?.model_name, mainModels?.entities]
] ||
onnxModels?.entities[
`${field.value?.base_model}/onnx/${field.value?.model_name}`
]) ??
null,
[
field.value?.base_model,
field.value?.model_name,
mainModels?.entities,
onnxModels?.entities,
]
);

const handleChangeModel = useCallback(
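The selectors now resolve the chosen model against two RTK Query caches: main models are keyed `${base_model}/main/${model_name}` and ONNX models `${base_model}/onnx/${model_name}`. A small sketch of that two-cache lookup, assuming only the id format shown above (EntityState is the @reduxjs/toolkit type; the helper name is invented here):

import type { EntityState } from '@reduxjs/toolkit';

type SelectedModel = { base_model?: string; model_name?: string };

// Try the main-model cache first, then the ONNX cache, mirroring the memo above.
const findModelEntity = <T>(
  selected: SelectedModel,
  main?: EntityState<T>['entities'],
  onnx?: EntityState<T>['entities']
) =>
  (main?.[`${selected.base_model}/main/${selected.model_name}`] ||
    onnx?.[`${selected.base_model}/onnx/${selected.model_name}`]) ??
  null;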
@ -9,6 +9,7 @@ import {
CLIP_SKIP,
LORA_LOADER,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
@ -17,7 +18,8 @@ import {
export const addLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
baseNodeId: string,
modelLoader: string = MAIN_MODEL_LOADER
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
@ -40,6 +42,10 @@ export const addLoRAsToGraph = (
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === ONNX_MODEL_LOADER &&
['unet'].includes(e.source.field)
)
);
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
@ -75,12 +81,11 @@ export const addLoRAsToGraph = (

// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;

if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: MAIN_MODEL_LOADER,
node_id: modelLoader,
field: 'unet',
},
destination: {
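addLoRAsToGraph now takes the loader node id as a parameter (defaulting to MAIN_MODEL_LOADER), so the LoRA chain attaches to whichever loader the graph actually uses. A sketch of how a caller might choose that id; pickModelLoader is a hypothetical helper, the two constants are the ones defined in constants.ts in this PR:

const MAIN_MODEL_LOADER = 'main_model_loader';
const ONNX_MODEL_LOADER = 'onnx_model_loader';

// Hypothetical helper: ONNX models route the UNet/CLIP edges through the ONNX loader node.
const pickModelLoader = (modelType: string): string =>
  modelType.includes('onnx') ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;

// e.g. addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, pickModelLoader(model.model_type));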
@ -9,13 +9,15 @@ import {
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
METADATA_ACCUMULATOR,
ONNX_MODEL_LOADER,
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';

export const addVAEToGraph = (
state: RootState,
graph: NonNullableGraph
graph: NonNullableGraph,
modelLoader: string = MAIN_MODEL_LOADER
): void => {
const { vae } = state.generation;

@ -32,12 +34,12 @@ export const addVAEToGraph = (
vae_model: vae,
};
}

const isOnnxModel = modelLoader == ONNX_MODEL_LOADER;
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
node_id: isAutoVae ? modelLoader : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
@ -49,8 +51,8 @@ export const addVAEToGraph = (
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
graph.edges.push({
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
node_id: isAutoVae ? modelLoader : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
@ -62,8 +64,8 @@ export const addVAEToGraph = (
if (graph.id === INPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
node_id: isAutoVae ? modelLoader : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: INPAINT,
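The addVAEToGraph change is twofold: the auto-VAE edge now originates from whichever loader the graph uses, and the source field differs for ONNX, because the ONNX loader exposes separate vae_decoder / vae_encoder outputs instead of a single vae. A one-line sketch of the field selection, extracted from the edges above:

// 'vae_decoder' is only used when the auto VAE comes from the ONNX model loader.
const vaeSourceField = (isAutoVae: boolean, isOnnxModel: boolean): string =>
  isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae';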
@ -12,6 +12,7 @@ import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
@ -52,7 +53,8 @@ export const buildCanvasTextToImageGraph = (
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;

const onnx_model_type = model.model_type.includes('onnx');
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@ -63,17 +65,18 @@ export const buildCanvasTextToImageGraph = (
*/

// copy-pasted graph from node editor, filled in with state values & friendly node ids
// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: TEXT_TO_IMAGE_GRAPH,
nodes: {
[POSITIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
prompt: negativePrompt,
@ -87,16 +90,16 @@ export const buildCanvasTextToImageGraph = (
use_cpu,
},
[TEXT_TO_LATENTS]: {
type: 't2l',
type: onnx_model_type ? 't2l_onnx' : 't2l',
id: TEXT_TO_LATENTS,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
},
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
[model_loader]: {
type: model_loader,
id: model_loader,
is_intermediate: true,
model,
},
@ -107,7 +110,7 @@ export const buildCanvasTextToImageGraph = (
skipped_layers: clipSkip,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: !shouldAutoSave,
},
@ -135,7 +138,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'clip',
},
destination: {
@ -165,7 +168,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'unet',
},
destination: {
@ -229,10 +232,10 @@ export const buildCanvasTextToImageGraph = (
});

// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);

// optionally add custom VAE
addVAEToGraph(state, graph);
addVAEToGraph(state, graph, model_loader);

// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
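When an ONNX model is selected, the canvas builder above (and the linear builder below) keeps the same graph shape and only swaps node types and the loader node. A sketch of that substitution; nodeTypeFor is a hypothetical helper, the mapping itself is the one used in the graphs:

const ONNX_NODE_TYPE: Record<string, string> = {
  compel: 'prompt_onnx', // prompt conditioning
  t2l: 't2l_onnx', // text to latents
  l2i: 'l2i_onnx', // latents to image
  main_model_loader: 'onnx_model_loader',
};

const nodeTypeFor = (baseType: string, isOnnx: boolean): string =>
  isOnnx ? ONNX_NODE_TYPE[baseType] ?? baseType : baseType;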
@ -12,6 +12,7 @@ import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
@ -48,6 +49,8 @@ export const buildLinearTextToImageGraph = (
throw new Error('No model found in state');
}

const onnx_model_type = model.model_type.includes('onnx');
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@ -58,12 +61,14 @@ export const buildLinearTextToImageGraph = (
*/

// copy-pasted graph from node editor, filled in with state values & friendly node ids

// TODO: Actually create the graph correctly for ONNX
const graph: NonNullableGraph = {
id: TEXT_TO_IMAGE_GRAPH,
nodes: {
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
[model_loader]: {
type: model_loader,
id: model_loader,
model,
},
[CLIP_SKIP]: {
@ -72,12 +77,12 @@ export const buildLinearTextToImageGraph = (
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
type: onnx_model_type ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
@ -89,14 +94,14 @@ export const buildLinearTextToImageGraph = (
use_cpu,
},
[TEXT_TO_LATENTS]: {
type: 't2l',
type: onnx_model_type ? 't2l_onnx' : 't2l',
id: TEXT_TO_LATENTS,
cfg_scale,
scheduler,
steps,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
@ -104,7 +109,7 @@ export const buildLinearTextToImageGraph = (
edges: [
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'clip',
},
destination: {
@ -114,7 +119,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: MAIN_MODEL_LOADER,
node_id: model_loader,
field: 'unet',
},
destination: {
@ -218,10 +223,10 @@ export const buildLinearTextToImageGraph = (
});

// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);

// optionally add custom VAE
addVAEToGraph(state, graph);
addVAEToGraph(state, graph, model_loader);

// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
@ -10,6 +10,7 @@ export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const ONNX_MODEL_LOADER = 'onnx_model_loader';
export const VAE_LOADER = 'vae_loader';
export const LORA_LOADER = 'lora_loader';
export const CLIP_SKIP = 'clip_skip';
@ -15,8 +15,11 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es';
import {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
} from 'services/api/endpoints/models';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';

const selector = createSelector(
@ -35,6 +38,9 @@ const ParamMainModelSelect = () => {
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const { data: onnxModels, isLoading: onnxLoading } = useGetOnnxModelsQuery(
NON_REFINER_BASE_MODELS
);

const activeTabName = useAppSelector(activeTabNameSelector);

@ -59,17 +65,35 @@ const ParamMainModelSelect = () => {
group: MODEL_TYPE_MAP[model.base_model],
});
});
forEach(onnxModels?.entities, (model, id) => {
if (
!model ||
activeTabName === 'unifiedCanvas' ||
activeTabName === 'img2img'
) {
return;
}

data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});

return data;
}, [mainModels, activeTabName]);
}, [mainModels, onnxModels, activeTabName]);

// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
(mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ||
onnxModels?.entities[
`${model?.base_model}/onnx/${model?.model_name}`
]) ??
null,
[mainModels?.entities, model]
[mainModels?.entities, model, onnxModels?.entities]
);

const handleChangeModel = useCallback(
@ -89,7 +113,7 @@ const ParamMainModelSelect = () => {
[dispatch]
);

return isLoading ? (
return isLoading || onnxLoading ? (
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
@ -1,10 +1,10 @@
import { createAction } from '@reduxjs/toolkit';
import { ImageDTO, MainModelField } from 'services/api/types';
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';

export const initialImageSelected = createAction<ImageDTO | undefined>(
'generation/initialImageSelected'
);

export const modelSelected = createAction<MainModelField>(
export const modelSelected = createAction<MainModelField | OnnxModelField>(
'generation/modelSelected'
);
@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import { configChanged } from 'features/system/store/configSlice';
import { clamp } from 'lodash-es';
import { ImageDTO, MainModelField } from 'services/api/types';
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
import { clipSkipMap } from '../types/constants';
import {
CfgScaleParam,
@ -50,7 +50,7 @@ export interface GenerationState {
shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
model: MainModelField | null;
model: MainModelField | OnnxModelField | null;
vae: VaeModelParam | null;
vaePrecision: PrecisionParam;
seamlessXAxis: boolean;
@ -272,11 +272,12 @@ export const generationSlice = createSlice({
const defaultModel = action.payload.sd?.defaultModel;

if (defaultModel && !state.model) {
const [base_model, _model_type, model_name] = defaultModel.split('/');
const [base_model, model_type, model_name] = defaultModel.split('/');

const result = zMainModel.safeParse({
model_name,
base_model,
model_type,
});

if (result.success) {
@ -210,6 +210,14 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;

const zModelType = z.enum([
'vae',
'lora',
'onnx',
'main',
'controlnet',
'embedding',
]);
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);

export type BaseModelParam = z.infer<typeof zBaseModel>;
@ -221,12 +229,18 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
export const zMainModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
model_type: zModelType,
});

/**
* Type alias for model parameter, inferred from its zod schema
*/
export type MainModelParam = z.infer<typeof zMainModel>;

/**
* Type alias for model parameter, inferred from its zod schema
*/
export type OnnxModelParam = z.infer<typeof zMainModel>;
/**
* Validates/type-guards a value as a model parameter
*/
@ -8,11 +8,12 @@ export const modelIdToMainModelParam = (
mainModelId: string
): MainModelParam | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = mainModelId.split('/');
const [base_model, model_type, model_name] = mainModelId.split('/');

const result = zMainModel.safeParse({
base_model,
model_name,
model_type,
});

if (!result.success) {
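With model_type included, a model id now round-trips through the parser like this (the model name below is made up for illustration):

const id = 'sd-1/onnx/some-onnx-model';
const [base_model, model_type, model_name] = id.split('/');
// base_model === 'sd-1', model_type === 'onnx', model_name === 'some-onnx-model'
// zMainModel.safeParse({ base_model, model_name, model_type }) succeeds,
// since zModelType now includes 'onnx'.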
@ -8,7 +8,9 @@ import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
OnnxModelConfigEntity,
useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models';
@ -20,9 +22,9 @@ type ModelListProps = {
setSelectedModelId: (name: string | undefined) => void;
};

type ModelFormat = 'images' | 'checkpoint' | 'diffusers';
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';

type ModelType = 'main' | 'lora';
type ModelType = 'main' | 'lora' | 'onnx';

type CombinedModelFormat = ModelFormat | 'lora';

@ -61,6 +63,18 @@ const ModelList = (props: ModelListProps) => {
}),
});

const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
}),
});

const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
}),
});

const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
@ -85,10 +99,17 @@ const ModelList = (props: ModelListProps) => {
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
onClick={() => setModelFormatFilter('onnx')}
isChecked={modelFormatFilter === 'onnx'}
>
{t('modelManager.checkpointModels')}
{t('modelManager.onnxModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('olive')}
isChecked={modelFormatFilter === 'olive'}
>
{t('modelManager.oliveModels')}
</IAIButton>
<IAIButton
size="sm"
@ -147,6 +168,42 @@ const ModelList = (props: ModelListProps) => {
</Flex>
</StyledModelContainer>
)}
{['images', 'olive'].includes(modelFormatFilter) &&
filteredOliveModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Olives
</Text>
{filteredOliveModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'onnx'].includes(modelFormatFilter) &&
filteredOnnxModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Onnx
</Text>
{filteredOnnxModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
@ -173,7 +230,12 @@ const ModelList = (props: ModelListProps) => {

export default ModelList;

const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
const modelsFilter = <
T extends
| MainModelConfigEntity
| LoRAModelConfigEntity
| OnnxModelConfigEntity
>(
data: EntityState<T> | undefined,
model_type: ModelType,
model_format: ModelFormat | undefined,
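Both new filter buttons share the same query; ONNX and Olive entries have model_type 'onnx' and are distinguished only by model_format ('onnx' vs 'olive'). A minimal sketch of the matching rule the list filter applies, under those assumptions (the helper and its types are illustrative, not the component's exact code):

type FilterableModel = {
  model_name: string;
  model_type: string; // 'main' | 'lora' | 'onnx'
  model_format?: string; // 'checkpoint' | 'diffusers' | 'onnx' | 'olive' | ...
};

const matchesFilter = (
  m: FilterableModel,
  modelType: string,
  modelFormat: string | undefined,
  nameFilter: string
): boolean =>
  m.model_type === modelType &&
  (modelFormat === undefined || m.model_format === modelFormat) &&
  m.model_name.toLowerCase().includes(nameFilter.toLowerCase());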
@ -10,9 +10,11 @@ import {
ImportModelConfig,
LoRAModelConfig,
MainModelConfig,
OnnxModelConfig,
MergeModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
ModelType,
} from 'services/api/types';

import queryString from 'query-string';
@ -27,6 +29,8 @@ export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;

export type OnnxModelConfigEntity = OnnxModelConfig & { id: string };

export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };

export type ControlNetModelConfigEntity = ControlNetModelConfig & {
@ -41,6 +45,7 @@ export type VaeModelConfigEntity = VaeModelConfig & { id: string };

type AnyModelConfigEntity =
| MainModelConfigEntity
| OnnxModelConfigEntity
| LoRAModelConfigEntity
| ControlNetModelConfigEntity
| TextualInversionModelConfigEntity
@ -66,6 +71,7 @@ type UpdateLoRAModelResponse = UpdateMainModelResponse;
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
model_type: ModelType;
};

type DeleteMainModelResponse = void;
@ -119,6 +125,10 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});

const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -156,6 +166,49 @@ const createModelEntities = <T extends AnyModelConfigEntity>(

export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getOnnxModels: build.query<
EntityState<OnnxModelConfigEntity>,
BaseModelType[]
>({
query: (base_models) => {
const params = {
model_type: 'onnx',
base_models,
};

const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'OnnxModel', type: LIST_TAG },
];

if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'OnnxModel' as const,
id,
}))
);
}

return tags;
},
transformResponse: (
response: { models: OnnxModelConfig[] },
meta,
arg
) => {
const entities = createModelEntities<OnnxModelConfigEntity>(
response.models
);
return onnxModelsAdapter.setAll(
onnxModelsAdapter.getInitialState(),
entities
);
},
}),
getMainModels: build.query<
EntityState<MainModelConfigEntity>,
BaseModelType[]
@ -248,9 +301,9 @@ export const modelsApi = api.injectEndpoints({
DeleteMainModelResponse,
DeleteMainModelArg
>({
query: ({ base_model, model_name }) => {
query: ({ base_model, model_name, model_type }) => {
return {
url: `models/${base_model}/main/${model_name}`,
url: `models/${base_model}/${model_type}/${model_name}`,
method: 'DELETE',
};
},
@ -494,6 +547,7 @@ export const modelsApi = api.injectEndpoints({

export const {
useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
||||
|
360
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
360
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
@ -1381,7 +1381,7 @@ export type components = {
|
||||
* @description The nodes in this graph
|
||||
*/
|
||||
nodes?: {
|
||||
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | 
components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | 
components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||
};
|
||||
/**
|
||||
* Edges
|
||||
@ -1424,7 +1424,7 @@ export type components = {
|
||||
* @description The results of node executions
|
||||
*/
|
||||
results: {
|
||||
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
||||
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
||||
};
|
||||
/**
|
||||
* Errors
|
||||
@ -2562,10 +2562,10 @@ export type components = {
|
||||
/**
|
||||
* Infill Method
|
||||
* @description The method used to infill empty regions (px)
|
||||
* @default patchmatch
|
||||
* @default tile
|
||||
* @enum {string}
|
||||
*/
|
||||
infill_method?: "patchmatch" | "tile" | "solid";
|
||||
infill_method?: "tile" | "solid";
|
||||
/**
|
||||
* Inpaint Width
|
||||
* @description The width of the inpaint region (px)
|
||||
@ -3173,6 +3173,8 @@ export type components = {
|
||||
model_name: string;
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
/** @description Model Type */
|
||||
model_type: components["schemas"]["ModelType"];
|
||||
};
|
||||
/**
|
||||
* MainModelLoaderInvocation
|
||||
@ -3618,7 +3620,7 @@ export type components = {
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
ModelType: "main" | "vae" | "lora" | "controlnet" | "embedding";
|
||||
ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding";
|
||||
/**
|
||||
* ModelVariantType
|
||||
* @description An enumeration.
|
||||
@ -3628,7 +3630,7 @@ export type components = {
|
||||
/** ModelsList */
|
||||
ModelsList: {
|
||||
/** Models */
|
||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
|
||||
models: (components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
|
||||
};
|
||||
/**
|
||||
* MultiplyInvocation
|
||||
@ -3778,6 +3780,261 @@ export type components = {
|
||||
*/
|
||||
image_resolution?: number;
|
||||
};
|
||||
/**
|
||||
* ONNXLatentsToImageInvocation
|
||||
* @description Generates an image from latents.
|
||||
*/
|
||||
ONNXLatentsToImageInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default l2i_onnx
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "l2i_onnx";
|
||||
/**
|
||||
* Latents
|
||||
* @description The latents to generate an image from
|
||||
*/
|
||||
latents?: components["schemas"]["LatentsField"];
|
||||
/**
|
||||
* Vae
|
||||
* @description Vae submodel
|
||||
*/
|
||||
vae?: components["schemas"]["VaeField"];
|
||||
/**
|
||||
* Metadata
|
||||
* @description Optional core metadata to be written to the image
|
||||
*/
|
||||
metadata?: components["schemas"]["CoreMetadata"];
|
||||
};
|
||||
/**
|
||||
* ONNXModelLoaderOutput
|
||||
* @description Model loader output
|
||||
*/
|
||||
ONNXModelLoaderOutput: {
|
||||
/**
|
||||
* Type
|
||||
* @default model_loader_output_onnx
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "model_loader_output_onnx";
|
||||
/**
|
||||
* Unet
|
||||
* @description UNet submodel
|
||||
*/
|
||||
unet?: components["schemas"]["UNetField"];
|
||||
/**
|
||||
* Clip
|
||||
* @description Tokenizer and text_encoder submodels
|
||||
*/
|
||||
clip?: components["schemas"]["ClipField"];
|
||||
/**
|
||||
* Vae Decoder
|
||||
* @description Vae submodel
|
||||
*/
|
||||
vae_decoder?: components["schemas"]["VaeField"];
|
||||
/**
|
||||
* Vae Encoder
|
||||
* @description Vae submodel
|
||||
*/
|
||||
vae_encoder?: components["schemas"]["VaeField"];
|
||||
};
|
||||
/**
|
||||
* ONNXPromptInvocation
|
||||
* @description A node to process inputs and produce outputs.
|
||||
* May use dependency injection in __init__ to receive providers.
|
||||
*/
|
||||
ONNXPromptInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default prompt_onnx
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "prompt_onnx";
|
||||
/**
|
||||
* Prompt
|
||||
* @description Prompt
|
||||
* @default
|
||||
*/
|
||||
prompt?: string;
|
||||
/**
|
||||
* Clip
|
||||
* @description Clip to use
|
||||
*/
|
||||
clip?: components["schemas"]["ClipField"];
|
||||
};
|
||||
/**
|
||||
* ONNXSD1ModelLoaderInvocation
|
||||
* @description Loading submodels of selected model.
|
||||
*/
|
||||
ONNXSD1ModelLoaderInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default sd1_model_loader_onnx
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "sd1_model_loader_onnx";
|
||||
/**
|
||||
* Model Name
|
||||
* @description Model to load
|
||||
* @default
|
||||
*/
|
||||
model_name?: string;
|
||||
};
|
||||
/** ONNXStableDiffusion1ModelConfig */
|
||||
ONNXStableDiffusion1ModelConfig: {
|
||||
/** Model Name */
|
||||
model_name: string;
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
/**
|
||||
* Model Type
|
||||
* @enum {string}
|
||||
*/
|
||||
model_type: "onnx";
|
||||
/** Path */
|
||||
path: string;
|
||||
/** Description */
|
||||
description?: string;
|
||||
/**
|
||||
* Model Format
|
||||
* @enum {string}
|
||||
*/
|
||||
model_format: "onnx";
|
||||
error?: components["schemas"]["ModelError"];
|
||||
variant: components["schemas"]["ModelVariantType"];
|
||||
};
|
||||
/** ONNXStableDiffusion2ModelConfig */
|
||||
ONNXStableDiffusion2ModelConfig: {
|
||||
/** Model Name */
|
||||
model_name: string;
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
/**
|
||||
* Model Type
|
||||
* @enum {string}
|
||||
*/
|
||||
model_type: "onnx";
|
||||
/** Path */
|
||||
path: string;
|
||||
/** Description */
|
||||
description?: string;
|
||||
/**
|
||||
* Model Format
|
||||
* @enum {string}
|
||||
*/
|
||||
model_format: "onnx";
|
||||
error?: components["schemas"]["ModelError"];
|
||||
variant: components["schemas"]["ModelVariantType"];
|
||||
prediction_type: components["schemas"]["SchedulerPredictionType"];
|
||||
/** Upcast Attention */
|
||||
upcast_attention: boolean;
|
||||
};
|
||||
/**
|
||||
* ONNXTextToLatentsInvocation
|
||||
* @description Generates latents from conditionings.
|
||||
*/
|
||||
ONNXTextToLatentsInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default t2l_onnx
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "t2l_onnx";
|
||||
/**
|
||||
* Positive Conditioning
|
||||
* @description Positive conditioning for generation
|
||||
*/
|
||||
positive_conditioning?: components["schemas"]["ConditioningField"];
|
||||
/**
|
||||
* Negative Conditioning
|
||||
* @description Negative conditioning for generation
|
||||
*/
|
||||
negative_conditioning?: components["schemas"]["ConditioningField"];
|
||||
/**
|
||||
* Noise
|
||||
* @description The noise to use
|
||||
*/
|
||||
noise?: components["schemas"]["LatentsField"];
|
||||
/**
|
||||
* Steps
|
||||
* @description The number of steps to use to generate the image
|
||||
* @default 10
|
||||
*/
|
||||
steps?: number;
|
||||
/**
|
||||
* Cfg Scale
|
||||
* @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
|
||||
* @default 7.5
|
||||
*/
|
||||
cfg_scale?: number | (number)[];
|
||||
/**
|
||||
* Scheduler
|
||||
* @description The scheduler to use
|
||||
* @default euler
|
||||
* @enum {string}
|
||||
*/
|
||||
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
|
||||
/**
|
||||
* Precision
|
||||
* @description The precision to use when generating latents
|
||||
* @default tensor(float16)
|
||||
* @enum {string}
|
||||
*/
|
||||
precision?: "tensor(bool)" | "tensor(int8)" | "tensor(uint8)" | "tensor(int16)" | "tensor(uint16)" | "tensor(int32)" | "tensor(uint32)" | "tensor(int64)" | "tensor(uint64)" | "tensor(float16)" | "tensor(float)" | "tensor(double)";
|
||||
/**
|
||||
* Unet
|
||||
* @description UNet submodel
|
||||
*/
|
||||
unet?: components["schemas"]["UNetField"];
|
||||
/**
|
||||
* Control
|
||||
* @description The control to use
|
||||
*/
|
||||
control?: components["schemas"]["ControlField"] | (components["schemas"]["ControlField"])[];
|
||||
};
|
||||
/**
|
||||
* OffsetPaginatedResults[BoardDTO]
|
||||
* @description Offset-paginated results
|
||||
@ -3830,6 +4087,49 @@ export type components = {
|
||||
*/
|
||||
total: number;
|
||||
};
|
||||
/**
|
||||
* OnnxModelField
|
||||
* @description Onnx model field
|
||||
*/
|
||||
OnnxModelField: {
|
||||
/**
|
||||
* Model Name
|
||||
* @description Name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
/** @description Model Type */
|
||||
model_type: components["schemas"]["ModelType"];
|
||||
};
|
||||
/**
|
||||
* OnnxModelLoaderInvocation
|
||||
* @description Loads a main model, outputting its submodels.
|
||||
*/
|
||||
OnnxModelLoaderInvocation: {
|
||||
/**
|
||||
* Id
|
||||
* @description The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Is Intermediate
|
||||
* @description Whether or not this node is an intermediate node.
|
||||
* @default false
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
/**
|
||||
* Type
|
||||
* @default onnx_model_loader
|
||||
* @enum {string}
|
||||
*/
|
||||
type?: "onnx_model_loader";
|
||||
/**
|
||||
* Model
|
||||
* @description The model to load
|
||||
*/
|
||||
model: components["schemas"]["OnnxModelField"];
|
||||
};
|
||||
/**
|
||||
* OpenposeImageProcessorInvocation
|
||||
* @description Applies Openpose processing to image
|
||||
@ -4963,6 +5263,12 @@ export type components = {
|
||||
*/
|
||||
antialias?: boolean;
|
||||
};
|
||||
/**
|
||||
* SchedulerPredictionType
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
SchedulerPredictionType: "epsilon" | "v_prediction" | "sample";
|
||||
/**
|
||||
* SegmentAnythingProcessorInvocation
|
||||
* @description Applies segment anything processing to image
|
||||
@ -5273,7 +5579,7 @@ export type components = {
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "scheduler" | "safety_checker";
|
||||
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker";
|
||||
/**
|
||||
* SubtractInvocation
|
||||
* @description Subtracts two numbers
|
||||
@ -5586,29 +5892,35 @@ export type components = {
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||
/**
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* ControlNetModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
ControlNetModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
@ -5719,7 +6031,7 @@ export type operations = {
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | 
components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | 
components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -5756,7 +6068,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | 
components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | 
components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -6020,14 +6332,14 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
responses: {
/** @description The model was updated successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Bad request */
@ -6058,7 +6370,7 @@ export type operations = {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description The model could not be found */
@ -6084,14 +6396,14 @@ export type operations = {
add_model: {
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
responses: {
/** @description The model added successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description The model could not be found */
@ -6131,7 +6443,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Bad request */
@ -6220,7 +6532,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Incompatible models */

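Read as a whole, the generated-schema changes above do two things: the node request-body unions gain the ONNX invocations (ONNXPromptInvocation, ONNXTextToLatentsInvocation, ONNXLatentsToImageInvocation, ONNXSD1ModelLoaderInvocation, OnnxModelLoaderInvocation), and the model-config payloads for update/import/add/convert gain ONNXStableDiffusion1ModelConfig and ONNXStableDiffusion2ModelConfig. A minimal client-side sketch of posting one of the new nodes follows; the endpoint path, the onnx_model_loader type tag, and the shape of the model field are assumptions modelled on the other loader invocations, not taken from this diff.

// Hedged sketch: URL, type tag, and field names are assumptions, not part of this commit.
const addOnnxLoaderNode = async (sessionId: string): Promise<unknown> => {
  const response = await fetch(`/api/v1/sessions/${sessionId}/nodes`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      id: 'onnx_model_loader_1',
      type: 'onnx_model_loader', // assumed discriminator for OnnxModelLoaderInvocation
      model: { model_name: 'stable-diffusion-v1-5', base_model: 'sd-1' }, // assumed OnnxModelField shape
    }),
  });
  return response.json();
};
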
@ -32,6 +32,7 @@ export type ModelType = components['schemas']['ModelType'];
export type SubModelType = components['schemas']['SubModelType'];
export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField'];
export type OnnxModelField = components['schemas']['OnnxModelField'];
export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ControlNetModelField =
@ -58,6 +59,8 @@ export type DiffusersModelConfig =
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type OnnxModelConfig = components['schemas']['ONNXStableDiffusion1ModelConfig']
| components['schemas']['StableDiffusionXLModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
@ -65,7 +68,8 @@ export type AnyModelConfig =
| VaeModelConfig
| ControlNetModelConfig
| TextualInversionModelConfig
| MainModelConfig;
| MainModelConfig
| OnnxModelConfig;

export type MergeModelConfig = components['schemas']['Body_merge_models'];
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
@ -127,6 +131,9 @@ export type ImageCollectionInvocation = TypeReq<
export type MainModelLoaderInvocation = TypeReq<
components['schemas']['MainModelLoaderInvocation']
>;
export type OnnxModelLoaderInvocation = TypeReq<
components['schemas']['OnnxModelLoaderInvocation']
>;
export type LoraLoaderInvocation = TypeReq<
components['schemas']['LoraLoaderInvocation']
>;

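On the hand-written types side, the additions are narrow: OnnxModelField, OnnxModelConfig, and OnnxModelLoaderInvocation are re-exported from the generated schema, and OnnxModelConfig joins the AnyModelConfig union. A minimal type-level sketch of how existing frontend code can pick these up; the import path is an assumption, and no schema fields are touched:

// Hedged sketch: only exercises the new type aliases; the module path is assumed.
import type {
  AnyModelConfig,
  OnnxModelConfig,
  OnnxModelLoaderInvocation,
} from 'services/api/types';

// Because OnnxModelConfig is folded into AnyModelConfig, code written against
// AnyModelConfig now accepts ONNX model configs without a signature change.
const describeModel = (config: AnyModelConfig): string => JSON.stringify(config);

// Collections of pending graph nodes can hold ONNX loaders behind the new alias.
const pendingOnnxLoaders: OnnxModelLoaderInvocation[] = [];

// A dedicated list of ONNX configs stays type-safe as more ONNX variants are added.
const onnxConfigs: OnnxModelConfig[] = [];
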
@ -61,6 +61,7 @@ dependencies = [
"numpy",
"npyscreen",
"omegaconf",
"onnx",
"opencv-python",
"pydantic==1.*",
"picklescan",
@ -103,6 +104,15 @@ dependencies = [
"xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
]
"onnx" = [
"onnxruntime",
]
"onnx-cuda" = [
"onnxruntime-gpu",
]
"onnx-directml" = [
"onnxruntime-directml",
]

[project.scripts]
@ -180,4 +190,4 @@ output = "coverage/index.xml"
max-line-length = 120

[tool.black]
line-length = 120
line-length = 120