mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ONNX Support (#3562)
Note: this branch based on #3548, not on main While find out what needs to be done to implement onnx, found that I can do draft of it pretty quickly, so... here it is) Supports LoRA and TI. As example - cat with sadcatmeme lora: ![image](https://github.com/invoke-ai/InvokeAI/assets/7768370/dbd1a5df-0629-4741-94b3-8e09f4b4d5a3) ![image](https://github.com/invoke-ai/InvokeAI/assets/7768370/d918836c-fdc7-43c0-aa81-dde9182f2e0f)
This commit is contained in:
commit
81654daed7
@ -455,7 +455,7 @@ def get_torch_source() -> (Union[str, None], str):
|
|||||||
device = graphical_accelerator()
|
device = graphical_accelerator()
|
||||||
|
|
||||||
url = None
|
url = None
|
||||||
optional_modules = None
|
optional_modules = "[onnx]"
|
||||||
if OS == "Linux":
|
if OS == "Linux":
|
||||||
if device == "rocm":
|
if device == "rocm":
|
||||||
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||||
@ -464,7 +464,10 @@ def get_torch_source() -> (Union[str, None], str):
|
|||||||
|
|
||||||
if device == "cuda":
|
if device == "cuda":
|
||||||
url = "https://download.pytorch.org/whl/cu117"
|
url = "https://download.pytorch.org/whl/cu117"
|
||||||
optional_modules = "[xformers]"
|
optional_modules = "[xformers,onnx-cuda]"
|
||||||
|
if device == "cuda_and_dml":
|
||||||
|
url = "https://download.pytorch.org/whl/cu117"
|
||||||
|
optional_modules = "[xformers,onnx-directml]"
|
||||||
|
|
||||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||||
|
|
||||||
|
@ -167,6 +167,10 @@ def graphical_accelerator():
|
|||||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
||||||
"cuda",
|
"cuda",
|
||||||
)
|
)
|
||||||
|
nvidia_with_dml = (
|
||||||
|
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
||||||
|
"cuda_and_dml",
|
||||||
|
)
|
||||||
amd = (
|
amd = (
|
||||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||||
"rocm",
|
"rocm",
|
||||||
@ -181,7 +185,7 @@ def graphical_accelerator():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if OS == "Windows":
|
if OS == "Windows":
|
||||||
options = [nvidia, cpu]
|
options = [nvidia, nvidia_with_dml, cpu]
|
||||||
if OS == "Linux":
|
if OS == "Linux":
|
||||||
options = [nvidia, amd, cpu]
|
options = [nvidia, amd, cpu]
|
||||||
elif OS == "Darwin":
|
elif OS == "Darwin":
|
||||||
|
@ -1,6 +1,14 @@
|
|||||||
from typing import Literal, Optional, Union, List, Annotated
|
from typing import Literal, Optional, Union, List, Annotated
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
|
from .model import ClipField
|
||||||
|
|
||||||
|
from ...backend.util.devices import torch_dtype
|
||||||
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
|
@ -24,6 +24,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
)
|
)
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
from ...backend.model_management import ModelPatcher
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
|
@ -53,6 +53,7 @@ class MainModelField(BaseModel):
|
|||||||
|
|
||||||
model_name: str = Field(description="Name of the model")
|
model_name: str = Field(description="Name of the model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelField(BaseModel):
|
class LoRAModelField(BaseModel):
|
||||||
|
578
invokeai/app/invocations/onnx.py
Normal file
578
invokeai/app/invocations/onnx.py
Normal file
@ -0,0 +1,578 @@
|
|||||||
|
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||||
|
|
||||||
|
from contextlib import ExitStack
|
||||||
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import re
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||||
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
|
from ...backend.model_management import ONNXModelPatcher
|
||||||
|
from ...backend.util import choose_torch_device
|
||||||
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
|
from .compel import ConditioningField
|
||||||
|
from .controlnet_image_processors import ControlField
|
||||||
|
from .image import ImageOutput
|
||||||
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
|
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||||
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
from .model import ClipField
|
||||||
|
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
|
||||||
|
from .compel import CompelOutput
|
||||||
|
|
||||||
|
|
||||||
|
ORT_TO_NP_TYPE = {
|
||||||
|
"tensor(bool)": np.bool_,
|
||||||
|
"tensor(int8)": np.int8,
|
||||||
|
"tensor(uint8)": np.uint8,
|
||||||
|
"tensor(int16)": np.int16,
|
||||||
|
"tensor(uint16)": np.uint16,
|
||||||
|
"tensor(int32)": np.int32,
|
||||||
|
"tensor(uint32)": np.uint32,
|
||||||
|
"tensor(int64)": np.int64,
|
||||||
|
"tensor(uint64)": np.uint64,
|
||||||
|
"tensor(float16)": np.float16,
|
||||||
|
"tensor(float)": np.float32,
|
||||||
|
"tensor(double)": np.float64,
|
||||||
|
}
|
||||||
|
|
||||||
|
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXPromptInvocation(BaseInvocation):
|
||||||
|
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||||
|
|
||||||
|
prompt: str = Field(default="", description="Prompt")
|
||||||
|
clip: ClipField = Field(None, description="Clip to use")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
|
**self.clip.tokenizer.dict(),
|
||||||
|
)
|
||||||
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
|
**self.clip.text_encoder.dict(),
|
||||||
|
)
|
||||||
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
||||||
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||||
|
loras = [
|
||||||
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||||
|
for lora in self.clip.loras
|
||||||
|
]
|
||||||
|
|
||||||
|
ti_list = []
|
||||||
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||||
|
name = trigger[1:-1]
|
||||||
|
try:
|
||||||
|
ti_list.append(
|
||||||
|
# stack.enter_context(
|
||||||
|
# context.services.model_manager.get_model(
|
||||||
|
# model_name=name,
|
||||||
|
# base_model=self.clip.text_encoder.base_model,
|
||||||
|
# model_type=ModelType.TextualInversion,
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
context.services.model_manager.get_model(
|
||||||
|
model_name=name,
|
||||||
|
base_model=self.clip.text_encoder.base_model,
|
||||||
|
model_type=ModelType.TextualInversion,
|
||||||
|
).context.model
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# print(e)
|
||||||
|
# import traceback
|
||||||
|
# print(traceback.format_exc())
|
||||||
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
if loras or ti_list:
|
||||||
|
text_encoder.release_session()
|
||||||
|
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
|
||||||
|
orig_tokenizer, text_encoder, ti_list
|
||||||
|
) as (tokenizer, ti_manager):
|
||||||
|
text_encoder.create_session()
|
||||||
|
|
||||||
|
# copy from
|
||||||
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
||||||
|
text_inputs = tokenizer(
|
||||||
|
self.prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
text_input_ids = text_inputs.input_ids
|
||||||
|
"""
|
||||||
|
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
||||||
|
|
||||||
|
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||||
|
removed_text = self.tokenizer.batch_decode(
|
||||||
|
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||||
|
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||||
|
|
||||||
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
|
|
||||||
|
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||||
|
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||||
|
|
||||||
|
return CompelOutput(
|
||||||
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Text to image
|
||||||
|
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||||
|
"""Generates latents from conditionings."""
|
||||||
|
|
||||||
|
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
# fmt: off
|
||||||
|
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||||
|
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||||
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
|
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
|
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
|
||||||
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
|
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
|
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
@validator("cfg_scale")
|
||||||
|
def ge_one(cls, v):
|
||||||
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
|
if isinstance(v, list):
|
||||||
|
for i in v:
|
||||||
|
if i < 1:
|
||||||
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
else:
|
||||||
|
if v < 1:
|
||||||
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
|
return v
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["latents"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model",
|
||||||
|
"control": "control",
|
||||||
|
# "cfg_scale": "float",
|
||||||
|
"cfg_scale": "number",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# based on
|
||||||
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||||
|
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
if isinstance(c, torch.Tensor):
|
||||||
|
c = c.cpu().numpy()
|
||||||
|
if isinstance(uc, torch.Tensor):
|
||||||
|
uc = uc.cpu().numpy()
|
||||||
|
device = torch.device(choose_torch_device())
|
||||||
|
prompt_embeds = np.concatenate([uc, c])
|
||||||
|
|
||||||
|
latents = context.services.latents.get(self.noise.latents_name)
|
||||||
|
if isinstance(latents, torch.Tensor):
|
||||||
|
latents = latents.cpu().numpy()
|
||||||
|
|
||||||
|
# TODO: better execution device handling
|
||||||
|
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
|
||||||
|
|
||||||
|
# get the initial random noise unless the user supplied it
|
||||||
|
do_classifier_free_guidance = True
|
||||||
|
# latents_dtype = prompt_embeds.dtype
|
||||||
|
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||||
|
# if latents.shape != latents_shape:
|
||||||
|
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||||
|
|
||||||
|
scheduler = get_scheduler(
|
||||||
|
context=context,
|
||||||
|
scheduler_info=self.unet.scheduler,
|
||||||
|
scheduler_name=self.scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
def torch2numpy(latent: torch.Tensor):
|
||||||
|
return latent.cpu().numpy()
|
||||||
|
|
||||||
|
def numpy2torch(latent, device):
|
||||||
|
return torch.from_numpy(latent).to(device)
|
||||||
|
|
||||||
|
def dispatch_progress(
|
||||||
|
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||||
|
) -> None:
|
||||||
|
stable_diffusion_step_callback(
|
||||||
|
context=context,
|
||||||
|
intermediate_state=intermediate_state,
|
||||||
|
node=self.dict(),
|
||||||
|
source_node_id=source_node_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler.set_timesteps(self.steps)
|
||||||
|
latents = latents * np.float64(scheduler.init_noise_sigma)
|
||||||
|
|
||||||
|
extra_step_kwargs = dict()
|
||||||
|
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
|
extra_step_kwargs.update(
|
||||||
|
eta=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||||
|
|
||||||
|
with unet_info as unet, ExitStack() as stack:
|
||||||
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
|
loras = [
|
||||||
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||||
|
for lora in self.unet.loras
|
||||||
|
]
|
||||||
|
|
||||||
|
if loras:
|
||||||
|
unet.release_session()
|
||||||
|
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||||
|
# TODO:
|
||||||
|
_, _, h, w = latents.shape
|
||||||
|
unet.create_session(h, w)
|
||||||
|
|
||||||
|
timestep_dtype = next(
|
||||||
|
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
||||||
|
)
|
||||||
|
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||||
|
for i in tqdm(range(len(scheduler.timesteps))):
|
||||||
|
t = scheduler.timesteps[i]
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
||||||
|
latent_model_input = latent_model_input.cpu().numpy()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
||||||
|
noise_pred = noise_pred[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||||
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
scheduler_output = scheduler.step(
|
||||||
|
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
||||||
|
)
|
||||||
|
latents = torch2numpy(scheduler_output.prev_sample)
|
||||||
|
|
||||||
|
state = PipelineIntermediateState(
|
||||||
|
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
|
||||||
|
)
|
||||||
|
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
# if callback is not None and i % callback_steps == 0:
|
||||||
|
# callback(i, t, latents)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
context.services.latents.save(name, latents)
|
||||||
|
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
|
||||||
|
|
||||||
|
|
||||||
|
# Latent to image
|
||||||
|
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||||
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
|
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||||
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
|
metadata: Optional[CoreMetadata] = Field(
|
||||||
|
default=None, description="Optional core metadata to be written to the image"
|
||||||
|
)
|
||||||
|
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["latents", "image"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
|
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||||
|
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||||
|
|
||||||
|
vae_info = context.services.model_manager.get_model(
|
||||||
|
**self.vae.vae.dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# clear memory as vae decode can request a lot
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
with vae_info as vae:
|
||||||
|
vae.create_session()
|
||||||
|
|
||||||
|
# copied from
|
||||||
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
||||||
|
latents = 1 / 0.18215 * latents
|
||||||
|
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||||
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
|
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
|
||||||
|
|
||||||
|
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||||
|
image = image.transpose((0, 2, 3, 1))
|
||||||
|
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Model loader output"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||||
|
|
||||||
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
||||||
|
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loading submodels of selected model."""
|
||||||
|
|
||||||
|
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
||||||
|
|
||||||
|
model_name: str = Field(default="", description="Model to load")
|
||||||
|
# TODO: precision?
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||||
|
model_name = "stable-diffusion-v1-5"
|
||||||
|
base_model = BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unkown model name: {model_name}!")
|
||||||
|
|
||||||
|
return ONNXModelLoaderOutput(
|
||||||
|
unet=UNetField(
|
||||||
|
unet=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
submodel=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
submodel=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
submodel=SubModelType.Tokenizer,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
submodel=SubModelType.TextEncoder,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
vae_decoder=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
submodel=SubModelType.VaeDecoder,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
vae_encoder=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
submodel=SubModelType.VaeEncoder,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModelField(BaseModel):
|
||||||
|
"""Onnx model field"""
|
||||||
|
|
||||||
|
model_name: str = Field(description="Name of the model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
|
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||||
|
|
||||||
|
model: OnnxModelField = Field(description="The model to load")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "Onnx Model Loader",
|
||||||
|
"tags": ["model", "loader"],
|
||||||
|
"type_hints": {"model": "model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||||
|
base_model = self.model.base_model
|
||||||
|
model_name = self.model.model_name
|
||||||
|
model_type = ModelType.ONNX
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.Tokenizer,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.TextEncoder,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.UNet,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ONNXModelLoaderOutput(
|
||||||
|
unet=UNetField(
|
||||||
|
unet=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Tokenizer,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.TextEncoder,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
vae_decoder=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.VaeDecoder,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
vae_encoder=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.VaeEncoder,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
@ -12,6 +12,7 @@ from typing import Optional, List, Dict, Callable, Union, Set
|
|||||||
import requests
|
import requests
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
|
import onnx
|
||||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -302,8 +303,10 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
staging = Path(staging)
|
staging = Path(staging)
|
||||||
if "model_index.json" in files:
|
if "model_index.json" in files and "unet/model.onnx" not in files:
|
||||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||||
|
elif "unet/model.onnx" in files:
|
||||||
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
else:
|
else:
|
||||||
for suffix in ["safetensors", "bin"]:
|
for suffix in ["safetensors", "bin"]:
|
||||||
if f"pytorch_lora_weights.{suffix}" in files:
|
if f"pytorch_lora_weights.{suffix}" in files:
|
||||||
@ -368,7 +371,7 @@ class ModelInstall(object):
|
|||||||
model_format=info.format,
|
model_format=info.format,
|
||||||
)
|
)
|
||||||
legacy_conf = None
|
legacy_conf = None
|
||||||
if info.model_type == ModelType.Main:
|
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||||
attributes.update(
|
attributes.update(
|
||||||
dict(
|
dict(
|
||||||
variant=info.variant_type,
|
variant=info.variant_type,
|
||||||
@ -433,8 +436,13 @@ class ModelInstall(object):
|
|||||||
location = staging / name
|
location = staging / name
|
||||||
paths = list()
|
paths = list()
|
||||||
for filename in files:
|
for filename in files:
|
||||||
|
filePath = Path(filename)
|
||||||
p = hf_download_with_resume(
|
p = hf_download_with_resume(
|
||||||
repo_id, model_dir=location, model_name=filename, access_token=self.access_token
|
repo_id,
|
||||||
|
model_dir=location / filePath.parent,
|
||||||
|
model_name=filePath.name,
|
||||||
|
access_token=self.access_token,
|
||||||
|
subfolder=filePath.parent,
|
||||||
)
|
)
|
||||||
if p:
|
if p:
|
||||||
paths.append(p)
|
paths.append(p)
|
||||||
@ -482,11 +490,12 @@ def hf_download_with_resume(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
model_dest: Path = None,
|
model_dest: Path = None,
|
||||||
access_token: str = None,
|
access_token: str = None,
|
||||||
|
subfolder: str = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
url = hf_hub_url(repo_id, model_name)
|
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
||||||
|
|
||||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||||
open_mode = "wb"
|
open_mode = "wb"
|
||||||
|
@ -3,6 +3,7 @@ Initialization file for invokeai.backend.model_management
|
|||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
|
from .lora import ModelPatcher, ONNXModelPatcher
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
|
@ -6,11 +6,22 @@ from typing import Optional, Dict, Tuple, Any, Union, List
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
|
from diffusers.models import UNet2DConditionModel
|
||||||
|
from transformers import CLIPTextModel
|
||||||
|
from onnx import numpy_helper
|
||||||
|
from onnxruntime import OrtValue
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
# TODO: rename and split this file
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
class LoRALayerBase:
|
||||||
# rank: Optional[int]
|
# rank: Optional[int]
|
||||||
@ -698,3 +709,186 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
new_token_ids.extend(self.pad_tokens[token_id])
|
new_token_ids.extend(self.pad_tokens[token_id])
|
||||||
|
|
||||||
return new_token_ids
|
return new_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXModelPatcher:
|
||||||
|
from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_lora_unet(
|
||||||
|
cls,
|
||||||
|
unet: OnnxRuntimeModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_lora_text_encoder(
|
||||||
|
cls,
|
||||||
|
text_encoder: OnnxRuntimeModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
# based on
|
||||||
|
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_lora(
|
||||||
|
cls,
|
||||||
|
model: IAIOnnxRuntimeModel,
|
||||||
|
loras: List[Tuple[LoraModel, float]],
|
||||||
|
prefix: str,
|
||||||
|
):
|
||||||
|
from .models.base import IAIOnnxRuntimeModel
|
||||||
|
|
||||||
|
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||||
|
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||||
|
|
||||||
|
orig_weights = dict()
|
||||||
|
|
||||||
|
try:
|
||||||
|
blended_loras = dict()
|
||||||
|
|
||||||
|
for lora, lora_weight in loras:
|
||||||
|
for layer_key, layer in lora.layers.items():
|
||||||
|
if not layer_key.startswith(prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
layer.to(dtype=torch.float32)
|
||||||
|
layer_key = layer_key.replace(prefix, "")
|
||||||
|
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
|
||||||
|
if layer_key is blended_loras:
|
||||||
|
blended_loras[layer_key] += layer_weight
|
||||||
|
else:
|
||||||
|
blended_loras[layer_key] = layer_weight
|
||||||
|
|
||||||
|
node_names = dict()
|
||||||
|
for node in model.nodes.values():
|
||||||
|
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||||
|
|
||||||
|
for layer_key, lora_weight in blended_loras.items():
|
||||||
|
conv_key = layer_key + "_Conv"
|
||||||
|
gemm_key = layer_key + "_Gemm"
|
||||||
|
matmul_key = layer_key + "_MatMul"
|
||||||
|
|
||||||
|
if conv_key in node_names or gemm_key in node_names:
|
||||||
|
if conv_key in node_names:
|
||||||
|
conv_node = model.nodes[node_names[conv_key]]
|
||||||
|
else:
|
||||||
|
conv_node = model.nodes[node_names[gemm_key]]
|
||||||
|
|
||||||
|
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
||||||
|
orig_weight = model.tensors[weight_name]
|
||||||
|
|
||||||
|
if orig_weight.shape[-2:] == (1, 1):
|
||||||
|
if lora_weight.shape[-2:] == (1, 1):
|
||||||
|
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
|
||||||
|
else:
|
||||||
|
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
|
||||||
|
|
||||||
|
new_weight = np.expand_dims(new_weight, (2, 3))
|
||||||
|
else:
|
||||||
|
if orig_weight.shape != lora_weight.shape:
|
||||||
|
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
|
||||||
|
else:
|
||||||
|
new_weight = orig_weight + lora_weight
|
||||||
|
|
||||||
|
orig_weights[weight_name] = orig_weight
|
||||||
|
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
|
||||||
|
|
||||||
|
elif matmul_key in node_names:
|
||||||
|
weight_node = model.nodes[node_names[matmul_key]]
|
||||||
|
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
||||||
|
|
||||||
|
orig_weight = model.tensors[matmul_name]
|
||||||
|
new_weight = orig_weight + lora_weight.transpose()
|
||||||
|
|
||||||
|
orig_weights[matmul_name] = orig_weight
|
||||||
|
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# warn? err?
|
||||||
|
pass
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# restore original weights
|
||||||
|
for name, orig_weight in orig_weights.items():
|
||||||
|
model.tensors[name] = orig_weight
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_ti(
|
||||||
|
cls,
|
||||||
|
tokenizer: CLIPTokenizer,
|
||||||
|
text_encoder: IAIOnnxRuntimeModel,
|
||||||
|
ti_list: List[Any],
|
||||||
|
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||||
|
from .models.base import IAIOnnxRuntimeModel
|
||||||
|
|
||||||
|
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
|
||||||
|
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||||
|
|
||||||
|
orig_embeddings = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||||
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
|
|
||||||
|
def _get_trigger(ti, index):
|
||||||
|
trigger = ti.name
|
||||||
|
if index > 0:
|
||||||
|
trigger += f"-!pad-{i}"
|
||||||
|
return f"<{trigger}>"
|
||||||
|
|
||||||
|
# modify tokenizer
|
||||||
|
new_tokens_added = 0
|
||||||
|
for ti in ti_list:
|
||||||
|
for i in range(ti.embedding.shape[0]):
|
||||||
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
||||||
|
|
||||||
|
# modify text_encoder
|
||||||
|
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||||
|
|
||||||
|
embeddings = np.concatenate(
|
||||||
|
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
for ti in ti_list:
|
||||||
|
ti_tokens = []
|
||||||
|
for i in range(ti.embedding.shape[0]):
|
||||||
|
embedding = ti.embedding[i].detach().numpy()
|
||||||
|
trigger = _get_trigger(ti, i)
|
||||||
|
|
||||||
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
|
if token_id == ti_tokenizer.unk_token_id:
|
||||||
|
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
||||||
|
|
||||||
|
if embeddings[token_id].shape != embedding.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings[token_id] = embedding
|
||||||
|
ti_tokens.append(token_id)
|
||||||
|
|
||||||
|
if len(ti_tokens) > 1:
|
||||||
|
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
||||||
|
|
||||||
|
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
|
||||||
|
orig_embeddings.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
yield ti_tokenizer, ti_manager
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# restore
|
||||||
|
if orig_embeddings is not None:
|
||||||
|
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
|
||||||
|
@ -360,7 +360,8 @@ class ModelCache(object):
|
|||||||
# 2 refs:
|
# 2 refs:
|
||||||
# 1 from cache_entry
|
# 1 from cache_entry
|
||||||
# 1 from getrefcount function
|
# 1 from getrefcount function
|
||||||
if not cache_entry.locked and refs <= 2:
|
# 1 from onnx runtime object
|
||||||
|
if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||||
)
|
)
|
||||||
|
@ -27,7 +27,7 @@ class ModelProbeInfo(object):
|
|||||||
variant_type: ModelVariantType
|
variant_type: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
format: Literal["diffusers", "checkpoint", "lycoris"]
|
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||||
image_size: int
|
image_size: int
|
||||||
|
|
||||||
|
|
||||||
@ -41,6 +41,7 @@ class ModelProbe(object):
|
|||||||
PROBES = {
|
PROBES = {
|
||||||
"diffusers": {},
|
"diffusers": {},
|
||||||
"checkpoint": {},
|
"checkpoint": {},
|
||||||
|
"onnx": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
CLASS2TYPE = {
|
CLASS2TYPE = {
|
||||||
@ -53,7 +54,9 @@ class ModelProbe(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase):
|
def register_probe(
|
||||||
|
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
||||||
|
):
|
||||||
cls.PROBES[format][model_type] = probe_class
|
cls.PROBES[format][model_type] = probe_class
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -95,6 +98,7 @@ class ModelProbe(object):
|
|||||||
if format_type == "diffusers"
|
if format_type == "diffusers"
|
||||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||||
)
|
)
|
||||||
|
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
if not probe_class:
|
if not probe_class:
|
||||||
return None
|
return None
|
||||||
@ -168,6 +172,8 @@ class ModelProbe(object):
|
|||||||
if model:
|
if model:
|
||||||
class_name = model.__class__.__name__
|
class_name = model.__class__.__name__
|
||||||
else:
|
else:
|
||||||
|
if (folder_path / "unet/model.onnx").exists():
|
||||||
|
return ModelType.ONNX
|
||||||
if (folder_path / "learned_embeds.bin").exists():
|
if (folder_path / "learned_embeds.bin").exists():
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
@ -460,6 +466,17 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
|||||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXFolderProbe(FolderProbeBase):
|
||||||
|
def get_format(self) -> str:
|
||||||
|
return "onnx"
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase):
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.folder_path / "config.json"
|
config_file = self.folder_path / "config.json"
|
||||||
@ -497,3 +514,4 @@ ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
|||||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||||
|
@ -23,8 +23,11 @@ from .lora import LoRAModel
|
|||||||
from .controlnet import ControlNetModel # TODO:
|
from .controlnet import ControlNetModel # TODO:
|
||||||
from .textual_inversion import TextualInversionModel
|
from .textual_inversion import TextualInversionModel
|
||||||
|
|
||||||
|
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
|
ModelType.ONNX: ONNXStableDiffusion1Model,
|
||||||
ModelType.Main: StableDiffusion1Model,
|
ModelType.Main: StableDiffusion1Model,
|
||||||
ModelType.Vae: VaeModel,
|
ModelType.Vae: VaeModel,
|
||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
@ -32,6 +35,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
|
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||||
ModelType.Main: StableDiffusion2Model,
|
ModelType.Main: StableDiffusion2Model,
|
||||||
ModelType.Vae: VaeModel,
|
ModelType.Vae: VaeModel,
|
||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
@ -45,6 +49,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
|
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
ModelType.Main: StableDiffusionXLModel,
|
ModelType.Main: StableDiffusionXLModel,
|
||||||
@ -53,6 +58,7 @@ MODEL_CLASSES = {
|
|||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
|
ModelType.ONNX: ONNXStableDiffusion2Model,
|
||||||
},
|
},
|
||||||
# BaseModelType.Kandinsky2_1: {
|
# BaseModelType.Kandinsky2_1: {
|
||||||
# ModelType.Main: Kandinsky2_1Model,
|
# ModelType.Main: Kandinsky2_1Model,
|
||||||
|
@ -8,13 +8,23 @@ from abc import ABCMeta, abstractmethod
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from diffusers import DiffusionPipeline, ConfigMixin
|
from pathlib import Path
|
||||||
|
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
|
||||||
|
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
from onnx import numpy_helper
|
||||||
|
from onnxruntime import (
|
||||||
|
InferenceSession,
|
||||||
|
SessionOptions,
|
||||||
|
get_available_providers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
pass
|
pass
|
||||||
@ -37,6 +47,7 @@ class BaseModelType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
|
ONNX = "onnx"
|
||||||
Main = "main"
|
Main = "main"
|
||||||
Vae = "vae"
|
Vae = "vae"
|
||||||
Lora = "lora"
|
Lora = "lora"
|
||||||
@ -51,6 +62,8 @@ class SubModelType(str, Enum):
|
|||||||
Tokenizer = "tokenizer"
|
Tokenizer = "tokenizer"
|
||||||
Tokenizer2 = "tokenizer_2"
|
Tokenizer2 = "tokenizer_2"
|
||||||
Vae = "vae"
|
Vae = "vae"
|
||||||
|
VaeDecoder = "vae_decoder"
|
||||||
|
VaeEncoder = "vae_encoder"
|
||||||
Scheduler = "scheduler"
|
Scheduler = "scheduler"
|
||||||
SafetyChecker = "safety_checker"
|
SafetyChecker = "safety_checker"
|
||||||
# MoVQ = "movq"
|
# MoVQ = "movq"
|
||||||
@ -362,6 +375,8 @@ def calc_model_size_by_data(model) -> int:
|
|||||||
return _calc_pipeline_by_data(model)
|
return _calc_pipeline_by_data(model)
|
||||||
elif isinstance(model, torch.nn.Module):
|
elif isinstance(model, torch.nn.Module):
|
||||||
return _calc_model_by_data(model)
|
return _calc_model_by_data(model)
|
||||||
|
elif isinstance(model, IAIOnnxRuntimeModel):
|
||||||
|
return _calc_onnx_model_by_data(model)
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -382,6 +397,12 @@ def _calc_model_by_data(model) -> int:
|
|||||||
return mem
|
return mem
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_onnx_model_by_data(model) -> int:
|
||||||
|
tensor_size = model.tensors.size() * 2 # The session doubles this
|
||||||
|
mem = tensor_size # in bytes
|
||||||
|
return mem
|
||||||
|
|
||||||
|
|
||||||
def _fast_safetensors_reader(path: str):
|
def _fast_safetensors_reader(path: str):
|
||||||
checkpoint = dict()
|
checkpoint = dict()
|
||||||
device = torch.device("meta")
|
device = torch.device("meta")
|
||||||
@ -449,3 +470,208 @@ class SilenceWarnings(object):
|
|||||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||||
warnings.simplefilter("default")
|
warnings.simplefilter("default")
|
||||||
|
|
||||||
|
|
||||||
|
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||||
|
|
||||||
|
|
||||||
|
class IAIOnnxRuntimeModel:
|
||||||
|
class _tensor_access:
|
||||||
|
def __init__(self, model):
|
||||||
|
self.model = model
|
||||||
|
self.indexes = dict()
|
||||||
|
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||||
|
self.indexes[obj.name] = idx
|
||||||
|
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
value = self.model.proto.graph.initializer[self.indexes[key]]
|
||||||
|
return numpy_helper.to_array(value)
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: np.ndarray):
|
||||||
|
new_node = numpy_helper.from_array(value)
|
||||||
|
# set_external_data(new_node, location="in-memory-location")
|
||||||
|
new_node.name = key
|
||||||
|
# new_node.ClearField("raw_data")
|
||||||
|
del self.model.proto.graph.initializer[self.indexes[key]]
|
||||||
|
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
||||||
|
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
||||||
|
|
||||||
|
# __delitem__
|
||||||
|
|
||||||
|
def __contains__(self, key: str):
|
||||||
|
return self.indexes[key] in self.model.proto.graph.initializer
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
raise NotImplementedError("tensor.items")
|
||||||
|
# return [(obj.name, obj) for obj in self.raw_proto]
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self.indexes.keys()
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
raise NotImplementedError("tensor.values")
|
||||||
|
# return [obj for obj in self.raw_proto]
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
bytesSum = 0
|
||||||
|
for node in self.model.proto.graph.initializer:
|
||||||
|
bytesSum += sys.getsizeof(node.raw_data)
|
||||||
|
return bytesSum
|
||||||
|
|
||||||
|
class _access_helper:
|
||||||
|
def __init__(self, raw_proto):
|
||||||
|
self.indexes = dict()
|
||||||
|
self.raw_proto = raw_proto
|
||||||
|
for idx, obj in enumerate(raw_proto):
|
||||||
|
self.indexes[obj.name] = idx
|
||||||
|
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
return self.raw_proto[self.indexes[key]]
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value):
|
||||||
|
index = self.indexes[key]
|
||||||
|
del self.raw_proto[index]
|
||||||
|
self.raw_proto.insert(index, value)
|
||||||
|
|
||||||
|
# __delitem__
|
||||||
|
|
||||||
|
def __contains__(self, key: str):
|
||||||
|
return key in self.indexes
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return [(obj.name, obj) for obj in self.raw_proto]
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self.indexes.keys()
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return [obj for obj in self.raw_proto]
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, provider: Optional[str]):
|
||||||
|
self.path = model_path
|
||||||
|
self.session = None
|
||||||
|
self.provider = provider
|
||||||
|
"""
|
||||||
|
self.data_path = self.path + "_data"
|
||||||
|
if not os.path.exists(self.data_path):
|
||||||
|
print(f"Moving model tensors to separate file: {self.data_path}")
|
||||||
|
tmp_proto = onnx.load(model_path, load_external_data=True)
|
||||||
|
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
|
||||||
|
del tmp_proto
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
self.proto = onnx.load(model_path, load_external_data=False)
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.proto = onnx.load(model_path, load_external_data=True)
|
||||||
|
# self.data = dict()
|
||||||
|
# for tensor in self.proto.graph.initializer:
|
||||||
|
# name = tensor.name
|
||||||
|
|
||||||
|
# if tensor.HasField("raw_data"):
|
||||||
|
# npt = numpy_helper.to_array(tensor)
|
||||||
|
# orv = OrtValue.ortvalue_from_numpy(npt)
|
||||||
|
# # self.data[name] = orv
|
||||||
|
# # set_external_data(tensor, location="in-memory-location")
|
||||||
|
# tensor.name = name
|
||||||
|
# # tensor.ClearField("raw_data")
|
||||||
|
|
||||||
|
self.nodes = self._access_helper(self.proto.graph.node)
|
||||||
|
# self.initializers = self._access_helper(self.proto.graph.initializer)
|
||||||
|
# print(self.proto.graph.input)
|
||||||
|
# print(self.proto.graph.initializer)
|
||||||
|
|
||||||
|
self.tensors = self._tensor_access(self)
|
||||||
|
|
||||||
|
# TODO: integrate with model manager/cache
|
||||||
|
def create_session(self, height=None, width=None):
|
||||||
|
if self.session is None or self.session_width != width or self.session_height != height:
|
||||||
|
# onnx.save(self.proto, "tmp.onnx")
|
||||||
|
# onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
||||||
|
# TODO: something to be able to get weight when they already moved outside of model proto
|
||||||
|
# (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
|
||||||
|
sess = SessionOptions()
|
||||||
|
# self._external_data.update(**external_data)
|
||||||
|
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
|
||||||
|
# sess.enable_profiling = True
|
||||||
|
|
||||||
|
# sess.intra_op_num_threads = 1
|
||||||
|
# sess.inter_op_num_threads = 1
|
||||||
|
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
||||||
|
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
# sess.enable_cpu_mem_arena = True
|
||||||
|
# sess.enable_mem_pattern = True
|
||||||
|
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
||||||
|
self.session_height = height
|
||||||
|
self.session_width = width
|
||||||
|
if height and width:
|
||||||
|
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
||||||
|
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||||
|
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
||||||
|
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
||||||
|
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
|
||||||
|
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
|
||||||
|
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
||||||
|
providers = []
|
||||||
|
if self.provider:
|
||||||
|
providers.append(self.provider)
|
||||||
|
else:
|
||||||
|
providers = get_available_providers()
|
||||||
|
if "TensorrtExecutionProvider" in providers:
|
||||||
|
providers.remove("TensorrtExecutionProvider")
|
||||||
|
try:
|
||||||
|
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
# self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
|
||||||
|
# self.io_binding = self.session.io_binding()
|
||||||
|
|
||||||
|
def release_session(self):
|
||||||
|
self.session = None
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
return
|
||||||
|
|
||||||
|
def __call__(self, **kwargs):
|
||||||
|
if self.session is None:
|
||||||
|
raise Exception("You should call create_session before running model")
|
||||||
|
|
||||||
|
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
||||||
|
output_names = self.session.get_outputs()
|
||||||
|
# for k in inputs:
|
||||||
|
# self.io_binding.bind_cpu_input(k, inputs[k])
|
||||||
|
# for name in output_names:
|
||||||
|
# self.io_binding.bind_output(name.name)
|
||||||
|
# self.session.run_with_iobinding(self.io_binding, None)
|
||||||
|
# return self.io_binding.copy_outputs_to_cpu()
|
||||||
|
return self.session.run(None, inputs)
|
||||||
|
|
||||||
|
# compatability with diffusers load code
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
model_id: Union[str, Path],
|
||||||
|
subfolder: Union[str, Path] = None,
|
||||||
|
file_name: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
sess_options: Optional["SessionOptions"] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
file_name = file_name or ONNX_WEIGHTS_NAME
|
||||||
|
|
||||||
|
if os.path.isdir(model_id):
|
||||||
|
model_path = model_id
|
||||||
|
if subfolder is not None:
|
||||||
|
model_path = os.path.join(model_path, subfolder)
|
||||||
|
model_path = os.path.join(model_path, file_name)
|
||||||
|
|
||||||
|
else:
|
||||||
|
model_path = model_id
|
||||||
|
|
||||||
|
# load model from local directory
|
||||||
|
if not os.path.isfile(model_path):
|
||||||
|
raise Exception(f"Model not found: {model_path}")
|
||||||
|
|
||||||
|
# TODO: session options
|
||||||
|
return cls(model_path, provider=provider)
|
||||||
|
@ -0,0 +1,157 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import Field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal, Optional, Union
|
||||||
|
from .base import (
|
||||||
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
DiffusersModel,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
SilenceWarnings,
|
||||||
|
read_checkpoint_meta,
|
||||||
|
classproperty,
|
||||||
|
OnnxRuntimeModel,
|
||||||
|
IAIOnnxRuntimeModel,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionOnnxModelFormat(str, Enum):
|
||||||
|
Olive = "olive"
|
||||||
|
Onnx = "onnx"
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXStableDiffusion1Model(DiffusersModel):
|
||||||
|
class Config(ModelConfigBase):
|
||||||
|
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||||
|
variant: ModelVariantType
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert base_model == BaseModelType.StableDiffusion1
|
||||||
|
assert model_type == ModelType.ONNX
|
||||||
|
super().__init__(
|
||||||
|
model_path=model_path,
|
||||||
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
)
|
||||||
|
|
||||||
|
for child_name, child_type in self.child_types.items():
|
||||||
|
if child_type is OnnxRuntimeModel:
|
||||||
|
self.child_types[child_name] = IAIOnnxRuntimeModel
|
||||||
|
|
||||||
|
# TODO: check that no optimum models provided
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def probe_config(cls, path: str, **kwargs):
|
||||||
|
model_format = cls.detect_format(path)
|
||||||
|
in_channels = 4 # TODO:
|
||||||
|
|
||||||
|
if in_channels == 9:
|
||||||
|
variant = ModelVariantType.Inpaint
|
||||||
|
elif in_channels == 4:
|
||||||
|
variant = ModelVariantType.Normal
|
||||||
|
else:
|
||||||
|
raise Exception("Unkown stable diffusion 1.* model format")
|
||||||
|
|
||||||
|
return cls.create_config(
|
||||||
|
path=path,
|
||||||
|
model_format=model_format,
|
||||||
|
variant=variant,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, model_path: str):
|
||||||
|
# TODO: Detect onnx vs olive
|
||||||
|
return StableDiffusionOnnxModelFormat.Onnx
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXStableDiffusion2Model(DiffusersModel):
|
||||||
|
# TODO: check that configs overwriten properly
|
||||||
|
class Config(ModelConfigBase):
|
||||||
|
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
|
||||||
|
variant: ModelVariantType
|
||||||
|
prediction_type: SchedulerPredictionType
|
||||||
|
upcast_attention: bool
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert base_model == BaseModelType.StableDiffusion2
|
||||||
|
assert model_type == ModelType.ONNX
|
||||||
|
super().__init__(
|
||||||
|
model_path=model_path,
|
||||||
|
base_model=BaseModelType.StableDiffusion2,
|
||||||
|
model_type=ModelType.ONNX,
|
||||||
|
)
|
||||||
|
|
||||||
|
for child_name, child_type in self.child_types.items():
|
||||||
|
if child_type is OnnxRuntimeModel:
|
||||||
|
self.child_types[child_name] = IAIOnnxRuntimeModel
|
||||||
|
# TODO: check that no optimum models provided
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def probe_config(cls, path: str, **kwargs):
|
||||||
|
model_format = cls.detect_format(path)
|
||||||
|
in_channels = 4 # TODO:
|
||||||
|
|
||||||
|
if in_channels == 9:
|
||||||
|
variant = ModelVariantType.Inpaint
|
||||||
|
elif in_channels == 5:
|
||||||
|
variant = ModelVariantType.Depth
|
||||||
|
elif in_channels == 4:
|
||||||
|
variant = ModelVariantType.Normal
|
||||||
|
else:
|
||||||
|
raise Exception("Unkown stable diffusion 2.* model format")
|
||||||
|
|
||||||
|
if variant == ModelVariantType.Normal:
|
||||||
|
prediction_type = SchedulerPredictionType.VPrediction
|
||||||
|
upcast_attention = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
prediction_type = SchedulerPredictionType.Epsilon
|
||||||
|
upcast_attention = False
|
||||||
|
|
||||||
|
return cls.create_config(
|
||||||
|
path=path,
|
||||||
|
model_format=model_format,
|
||||||
|
variant=variant,
|
||||||
|
prediction_type=prediction_type,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, model_path: str):
|
||||||
|
# TODO: Detect onnx vs olive
|
||||||
|
return StableDiffusionOnnxModelFormat.Onnx
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
|
return model_path
|
169
invokeai/frontend/web/dist/assets/App-44cdaaf3.js
vendored
Normal file
169
invokeai/frontend/web/dist/assets/App-44cdaaf3.js
vendored
Normal file
File diff suppressed because one or more lines are too long
169
invokeai/frontend/web/dist/assets/App-ea7b7298.js
vendored
169
invokeai/frontend/web/dist/assets/App-ea7b7298.js
vendored
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
|||||||
import{A as m,f_ as Je,z as y,a4 as Ka,f$ as Xa,af as va,aj as d,g0 as b,g1 as t,g2 as Ya,g3 as h,g4 as ua,g5 as Ja,g6 as Qa,aI as Za,g7 as et,ad as rt,g8 as at}from"./index-9bb68e3a.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-ae002ae6.js";var za=String.raw,Ca=za`
|
import{A as m,f$ as Je,z as y,a4 as Ka,g0 as Xa,af as va,aj as d,g1 as b,g2 as t,g3 as Ya,g4 as h,g5 as ua,g6 as Ja,g7 as Qa,aI as Za,g8 as et,ad as rt,g9 as at}from"./index-18f2f740.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-b20a2267.js";var za=String.raw,Ca=za`
|
||||||
:root,
|
:root,
|
||||||
:host {
|
:host {
|
||||||
--chakra-vh: 100vh;
|
--chakra-vh: 100vh;
|
125
invokeai/frontend/web/dist/assets/index-18f2f740.js
vendored
Normal file
125
invokeai/frontend/web/dist/assets/index-18f2f740.js
vendored
Normal file
File diff suppressed because one or more lines are too long
125
invokeai/frontend/web/dist/assets/index-9bb68e3a.js
vendored
125
invokeai/frontend/web/dist/assets/index-9bb68e3a.js
vendored
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-9bb68e3a.js"></script>
|
<script type="module" crossorigin src="./assets/index-18f2f740.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
2
invokeai/frontend/web/dist/locales/en.json
vendored
2
invokeai/frontend/web/dist/locales/en.json
vendored
@ -342,6 +342,8 @@
|
|||||||
"diffusersModels": "Diffusers",
|
"diffusersModels": "Diffusers",
|
||||||
"loraModels": "LoRAs",
|
"loraModels": "LoRAs",
|
||||||
"safetensorModels": "SafeTensors",
|
"safetensorModels": "SafeTensors",
|
||||||
|
"onnxModels": "Onnx",
|
||||||
|
"oliveModels": "Olives",
|
||||||
"modelAdded": "Model Added",
|
"modelAdded": "Model Added",
|
||||||
"modelUpdated": "Model Updated",
|
"modelUpdated": "Model Updated",
|
||||||
"modelUpdateFailed": "Model Update Failed",
|
"modelUpdateFailed": "Model Update Failed",
|
||||||
|
@ -342,6 +342,8 @@
|
|||||||
"diffusersModels": "Diffusers",
|
"diffusersModels": "Diffusers",
|
||||||
"loraModels": "LoRAs",
|
"loraModels": "LoRAs",
|
||||||
"safetensorModels": "SafeTensors",
|
"safetensorModels": "SafeTensors",
|
||||||
|
"onnxModels": "Onnx",
|
||||||
|
"oliveModels": "Olives",
|
||||||
"modelAdded": "Model Added",
|
"modelAdded": "Model Added",
|
||||||
"modelUpdated": "Model Updated",
|
"modelUpdated": "Model Updated",
|
||||||
"modelUpdateFailed": "Model Update Failed",
|
"modelUpdateFailed": "Model Update Failed",
|
||||||
|
@ -36,7 +36,8 @@ export const addModelsLoadedListener = () => {
|
|||||||
action.payload.entities,
|
action.payload.entities,
|
||||||
(m) =>
|
(m) =>
|
||||||
m?.model_name === currentModel?.model_name &&
|
m?.model_name === currentModel?.model_name &&
|
||||||
m?.base_model === currentModel?.base_model
|
m?.base_model === currentModel?.base_model &&
|
||||||
|
m?.model_type === currentModel?.model_type
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isCurrentModelAvailable) {
|
if (isCurrentModelAvailable) {
|
||||||
@ -83,7 +84,8 @@ export const addModelsLoadedListener = () => {
|
|||||||
action.payload.entities,
|
action.payload.entities,
|
||||||
(m) =>
|
(m) =>
|
||||||
m?.model_name === currentModel?.model_name &&
|
m?.model_name === currentModel?.model_name &&
|
||||||
m?.base_model === currentModel?.base_model
|
m?.base_model === currentModel?.base_model &&
|
||||||
|
m?.model_type === currentModel?.model_type
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isCurrentModelAvailable) {
|
if (isCurrentModelAvailable) {
|
||||||
|
@ -47,9 +47,9 @@ export const addTabChangedListener = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// only store the model name and base model in redux
|
// only store the model name and base model in redux
|
||||||
const { base_model, model_name } = firstValidCanvasModel;
|
const { base_model, model_name, model_type } = firstValidCanvasModel;
|
||||||
|
|
||||||
dispatch(modelChanged({ base_model, model_name }));
|
dispatch(modelChanged({ base_model, model_name, model_type }));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -14,8 +14,11 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
|||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import {
|
||||||
|
useGetMainModelsQuery,
|
||||||
|
useGetOnnxModelsQuery,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
||||||
|
|
||||||
@ -28,6 +31,7 @@ const ModelInputFieldComponent = (
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||||
|
|
||||||
|
const { data: onnxModels } = useGetOnnxModelsQuery(NON_REFINER_BASE_MODELS);
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||||
NON_REFINER_BASE_MODELS
|
NON_REFINER_BASE_MODELS
|
||||||
);
|
);
|
||||||
@ -51,17 +55,39 @@ const ModelInputFieldComponent = (
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (onnxModels) {
|
||||||
|
forEach(onnxModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
return data;
|
return data;
|
||||||
}, [mainModels]);
|
}, [mainModels, onnxModels]);
|
||||||
|
|
||||||
// grab the full model entity from the RTK Query cache
|
// grab the full model entity from the RTK Query cache
|
||||||
// TODO: maybe we should just store the full model entity in state?
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() =>
|
() =>
|
||||||
mainModels?.entities[
|
(mainModels?.entities[
|
||||||
`${field.value?.base_model}/main/${field.value?.model_name}`
|
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||||
] ?? null,
|
] ||
|
||||||
[field.value?.base_model, field.value?.model_name, mainModels?.entities]
|
onnxModels?.entities[
|
||||||
|
`${field.value?.base_model}/onnx/${field.value?.model_name}`
|
||||||
|
]) ??
|
||||||
|
null,
|
||||||
|
[
|
||||||
|
field.value?.base_model,
|
||||||
|
field.value?.model_name,
|
||||||
|
mainModels?.entities,
|
||||||
|
onnxModels?.entities,
|
||||||
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
|
@ -9,6 +9,7 @@ import {
|
|||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
LORA_LOADER,
|
LORA_LOADER,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
|
ONNX_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
METADATA_ACCUMULATOR,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
@ -17,7 +18,8 @@ import {
|
|||||||
export const addLoRAsToGraph = (
|
export const addLoRAsToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string
|
baseNodeId: string,
|
||||||
|
modelLoader: string = MAIN_MODEL_LOADER
|
||||||
): void => {
|
): void => {
|
||||||
/**
|
/**
|
||||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||||
@ -40,6 +42,10 @@ export const addLoRAsToGraph = (
|
|||||||
!(
|
!(
|
||||||
e.source.node_id === MAIN_MODEL_LOADER &&
|
e.source.node_id === MAIN_MODEL_LOADER &&
|
||||||
['unet'].includes(e.source.field)
|
['unet'].includes(e.source.field)
|
||||||
|
) &&
|
||||||
|
!(
|
||||||
|
e.source.node_id === ONNX_MODEL_LOADER &&
|
||||||
|
['unet'].includes(e.source.field)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
||||||
@ -75,12 +81,11 @@ export const addLoRAsToGraph = (
|
|||||||
|
|
||||||
// add to graph
|
// add to graph
|
||||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||||
|
|
||||||
if (currentLoraIndex === 0) {
|
if (currentLoraIndex === 0) {
|
||||||
// first lora = start the lora chain, attach directly to model loader
|
// first lora = start the lora chain, attach directly to model loader
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: MAIN_MODEL_LOADER,
|
node_id: modelLoader,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
|
@ -9,13 +9,15 @@ import {
|
|||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
METADATA_ACCUMULATOR,
|
||||||
|
ONNX_MODEL_LOADER,
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
|
||||||
export const addVAEToGraph = (
|
export const addVAEToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph
|
graph: NonNullableGraph,
|
||||||
|
modelLoader: string = MAIN_MODEL_LOADER
|
||||||
): void => {
|
): void => {
|
||||||
const { vae } = state.generation;
|
const { vae } = state.generation;
|
||||||
|
|
||||||
@ -32,12 +34,12 @@ export const addVAEToGraph = (
|
|||||||
vae_model: vae,
|
vae_model: vae,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
const isOnnxModel = modelLoader == ONNX_MODEL_LOADER;
|
||||||
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
node_id: isAutoVae ? modelLoader : VAE_LOADER,
|
||||||
field: 'vae',
|
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: LATENTS_TO_IMAGE,
|
node_id: LATENTS_TO_IMAGE,
|
||||||
@ -49,8 +51,8 @@ export const addVAEToGraph = (
|
|||||||
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
node_id: isAutoVae ? modelLoader : VAE_LOADER,
|
||||||
field: 'vae',
|
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: IMAGE_TO_LATENTS,
|
node_id: IMAGE_TO_LATENTS,
|
||||||
@ -62,8 +64,8 @@ export const addVAEToGraph = (
|
|||||||
if (graph.id === INPAINT_GRAPH) {
|
if (graph.id === INPAINT_GRAPH) {
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
node_id: isAutoVae ? modelLoader : VAE_LOADER,
|
||||||
field: 'vae',
|
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: INPAINT,
|
node_id: INPAINT,
|
||||||
|
@ -12,6 +12,7 @@ import {
|
|||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
|
ONNX_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
METADATA_ACCUMULATOR,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
@ -52,7 +53,8 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
const use_cpu = shouldUseNoiseSettings
|
const use_cpu = shouldUseNoiseSettings
|
||||||
? shouldUseCpuNoise
|
? shouldUseCpuNoise
|
||||||
: initialGenerationState.shouldUseCpuNoise;
|
: initialGenerationState.shouldUseCpuNoise;
|
||||||
|
const onnx_model_type = model.model_type.includes('onnx');
|
||||||
|
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -63,17 +65,18 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
// TODO: Actually create the graph correctly for ONNX
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: TEXT_TO_IMAGE_GRAPH,
|
id: TEXT_TO_IMAGE_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
[POSITIVE_CONDITIONING]: {
|
[POSITIVE_CONDITIONING]: {
|
||||||
type: 'compel',
|
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||||
id: POSITIVE_CONDITIONING,
|
id: POSITIVE_CONDITIONING,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
prompt: positivePrompt,
|
prompt: positivePrompt,
|
||||||
},
|
},
|
||||||
[NEGATIVE_CONDITIONING]: {
|
[NEGATIVE_CONDITIONING]: {
|
||||||
type: 'compel',
|
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||||
id: NEGATIVE_CONDITIONING,
|
id: NEGATIVE_CONDITIONING,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
@ -87,16 +90,16 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
use_cpu,
|
use_cpu,
|
||||||
},
|
},
|
||||||
[TEXT_TO_LATENTS]: {
|
[TEXT_TO_LATENTS]: {
|
||||||
type: 't2l',
|
type: onnx_model_type ? 't2l_onnx' : 't2l',
|
||||||
id: TEXT_TO_LATENTS,
|
id: TEXT_TO_LATENTS,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[MAIN_MODEL_LOADER]: {
|
[model_loader]: {
|
||||||
type: 'main_model_loader',
|
type: model_loader,
|
||||||
id: MAIN_MODEL_LOADER,
|
id: model_loader,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
@ -107,7 +110,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
skipped_layers: clipSkip,
|
skipped_layers: clipSkip,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
is_intermediate: !shouldAutoSave,
|
is_intermediate: !shouldAutoSave,
|
||||||
},
|
},
|
||||||
@ -135,7 +138,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MAIN_MODEL_LOADER,
|
node_id: model_loader,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -165,7 +168,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MAIN_MODEL_LOADER,
|
node_id: model_loader,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -229,10 +232,10 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
|
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph);
|
addVAEToGraph(state, graph, model_loader);
|
||||||
|
|
||||||
// add dynamic prompts - also sets up core iteration and seed
|
// add dynamic prompts - also sets up core iteration and seed
|
||||||
addDynamicPromptsToGraph(state, graph);
|
addDynamicPromptsToGraph(state, graph);
|
||||||
|
@ -12,6 +12,7 @@ import {
|
|||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
|
ONNX_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
METADATA_ACCUMULATOR,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
@ -48,6 +49,8 @@ export const buildLinearTextToImageGraph = (
|
|||||||
throw new Error('No model found in state');
|
throw new Error('No model found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onnx_model_type = model.model_type.includes('onnx');
|
||||||
|
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -58,12 +61,14 @@ export const buildLinearTextToImageGraph = (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
|
||||||
|
// TODO: Actually create the graph correctly for ONNX
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: TEXT_TO_IMAGE_GRAPH,
|
id: TEXT_TO_IMAGE_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
[MAIN_MODEL_LOADER]: {
|
[model_loader]: {
|
||||||
type: 'main_model_loader',
|
type: model_loader,
|
||||||
id: MAIN_MODEL_LOADER,
|
id: model_loader,
|
||||||
model,
|
model,
|
||||||
},
|
},
|
||||||
[CLIP_SKIP]: {
|
[CLIP_SKIP]: {
|
||||||
@ -72,12 +77,12 @@ export const buildLinearTextToImageGraph = (
|
|||||||
skipped_layers: clipSkip,
|
skipped_layers: clipSkip,
|
||||||
},
|
},
|
||||||
[POSITIVE_CONDITIONING]: {
|
[POSITIVE_CONDITIONING]: {
|
||||||
type: 'compel',
|
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||||
id: POSITIVE_CONDITIONING,
|
id: POSITIVE_CONDITIONING,
|
||||||
prompt: positivePrompt,
|
prompt: positivePrompt,
|
||||||
},
|
},
|
||||||
[NEGATIVE_CONDITIONING]: {
|
[NEGATIVE_CONDITIONING]: {
|
||||||
type: 'compel',
|
type: onnx_model_type ? 'prompt_onnx' : 'compel',
|
||||||
id: NEGATIVE_CONDITIONING,
|
id: NEGATIVE_CONDITIONING,
|
||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
},
|
},
|
||||||
@ -89,14 +94,14 @@ export const buildLinearTextToImageGraph = (
|
|||||||
use_cpu,
|
use_cpu,
|
||||||
},
|
},
|
||||||
[TEXT_TO_LATENTS]: {
|
[TEXT_TO_LATENTS]: {
|
||||||
type: 't2l',
|
type: onnx_model_type ? 't2l_onnx' : 't2l',
|
||||||
id: TEXT_TO_LATENTS,
|
id: TEXT_TO_LATENTS,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
fp32: vaePrecision === 'fp32' ? true : false,
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
},
|
},
|
||||||
@ -104,7 +109,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
edges: [
|
edges: [
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MAIN_MODEL_LOADER,
|
node_id: model_loader,
|
||||||
field: 'clip',
|
field: 'clip',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -114,7 +119,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: MAIN_MODEL_LOADER,
|
node_id: model_loader,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -218,10 +223,10 @@ export const buildLinearTextToImageGraph = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
|
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph);
|
addVAEToGraph(state, graph, model_loader);
|
||||||
|
|
||||||
// add dynamic prompts - also sets up core iteration and seed
|
// add dynamic prompts - also sets up core iteration and seed
|
||||||
addDynamicPromptsToGraph(state, graph);
|
addDynamicPromptsToGraph(state, graph);
|
||||||
|
@ -10,6 +10,7 @@ export const RANDOM_INT = 'rand_int';
|
|||||||
export const RANGE_OF_SIZE = 'range_of_size';
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
export const ITERATE = 'iterate';
|
export const ITERATE = 'iterate';
|
||||||
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||||
|
export const ONNX_MODEL_LOADER = 'onnx_model_loader';
|
||||||
export const VAE_LOADER = 'vae_loader';
|
export const VAE_LOADER = 'vae_loader';
|
||||||
export const LORA_LOADER = 'lora_loader';
|
export const LORA_LOADER = 'lora_loader';
|
||||||
export const CLIP_SKIP = 'clip_skip';
|
export const CLIP_SKIP = 'clip_skip';
|
||||||
|
@ -15,8 +15,11 @@ import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainM
|
|||||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
|
import {
|
||||||
|
useGetMainModelsQuery,
|
||||||
|
useGetOnnxModelsQuery,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -35,6 +38,9 @@ const ParamMainModelSelect = () => {
|
|||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||||
NON_REFINER_BASE_MODELS
|
NON_REFINER_BASE_MODELS
|
||||||
);
|
);
|
||||||
|
const { data: onnxModels, isLoading: onnxLoading } = useGetOnnxModelsQuery(
|
||||||
|
NON_REFINER_BASE_MODELS
|
||||||
|
);
|
||||||
|
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
@ -59,17 +65,35 @@ const ParamMainModelSelect = () => {
|
|||||||
group: MODEL_TYPE_MAP[model.base_model],
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
forEach(onnxModels?.entities, (model, id) => {
|
||||||
|
if (
|
||||||
|
!model ||
|
||||||
|
activeTabName === 'unifiedCanvas' ||
|
||||||
|
activeTabName === 'img2img'
|
||||||
|
) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [mainModels, activeTabName]);
|
}, [mainModels, onnxModels, activeTabName]);
|
||||||
|
|
||||||
// grab the full model entity from the RTK Query cache
|
// grab the full model entity from the RTK Query cache
|
||||||
// TODO: maybe we should just store the full model entity in state?
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() =>
|
() =>
|
||||||
mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ??
|
(mainModels?.entities[`${model?.base_model}/main/${model?.model_name}`] ||
|
||||||
|
onnxModels?.entities[
|
||||||
|
`${model?.base_model}/onnx/${model?.model_name}`
|
||||||
|
]) ??
|
||||||
null,
|
null,
|
||||||
[mainModels?.entities, model]
|
[mainModels?.entities, model, onnxModels?.entities]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
@ -89,7 +113,7 @@ const ParamMainModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
return isLoading ? (
|
return isLoading || onnxLoading ? (
|
||||||
<IAIMantineSearchableSelect
|
<IAIMantineSearchableSelect
|
||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
placeholder="Loading..."
|
placeholder="Loading..."
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { ImageDTO, MainModelField } from 'services/api/types';
|
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
|
||||||
|
|
||||||
export const initialImageSelected = createAction<ImageDTO | undefined>(
|
export const initialImageSelected = createAction<ImageDTO | undefined>(
|
||||||
'generation/initialImageSelected'
|
'generation/initialImageSelected'
|
||||||
);
|
);
|
||||||
|
|
||||||
export const modelSelected = createAction<MainModelField>(
|
export const modelSelected = createAction<MainModelField | OnnxModelField>(
|
||||||
'generation/modelSelected'
|
'generation/modelSelected'
|
||||||
);
|
);
|
||||||
|
@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
|||||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { ImageDTO, MainModelField } from 'services/api/types';
|
import { ImageDTO, MainModelField, OnnxModelField } from 'services/api/types';
|
||||||
import { clipSkipMap } from '../types/constants';
|
import { clipSkipMap } from '../types/constants';
|
||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
@ -50,7 +50,7 @@ export interface GenerationState {
|
|||||||
shouldUseSymmetry: boolean;
|
shouldUseSymmetry: boolean;
|
||||||
horizontalSymmetrySteps: number;
|
horizontalSymmetrySteps: number;
|
||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: MainModelField | null;
|
model: MainModelField | OnnxModelField | null;
|
||||||
vae: VaeModelParam | null;
|
vae: VaeModelParam | null;
|
||||||
vaePrecision: PrecisionParam;
|
vaePrecision: PrecisionParam;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
@ -272,11 +272,12 @@ export const generationSlice = createSlice({
|
|||||||
const defaultModel = action.payload.sd?.defaultModel;
|
const defaultModel = action.payload.sd?.defaultModel;
|
||||||
|
|
||||||
if (defaultModel && !state.model) {
|
if (defaultModel && !state.model) {
|
||||||
const [base_model, _model_type, model_name] = defaultModel.split('/');
|
const [base_model, model_type, model_name] = defaultModel.split('/');
|
||||||
|
|
||||||
const result = zMainModel.safeParse({
|
const result = zMainModel.safeParse({
|
||||||
model_name,
|
model_name,
|
||||||
base_model,
|
base_model,
|
||||||
|
model_type,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (result.success) {
|
if (result.success) {
|
||||||
|
@ -210,6 +210,14 @@ export type HeightParam = z.infer<typeof zHeight>;
|
|||||||
export const isValidHeight = (val: unknown): val is HeightParam =>
|
export const isValidHeight = (val: unknown): val is HeightParam =>
|
||||||
zHeight.safeParse(val).success;
|
zHeight.safeParse(val).success;
|
||||||
|
|
||||||
|
const zModelType = z.enum([
|
||||||
|
'vae',
|
||||||
|
'lora',
|
||||||
|
'onnx',
|
||||||
|
'main',
|
||||||
|
'controlnet',
|
||||||
|
'embedding',
|
||||||
|
]);
|
||||||
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||||
|
|
||||||
export type BaseModelParam = z.infer<typeof zBaseModel>;
|
export type BaseModelParam = z.infer<typeof zBaseModel>;
|
||||||
@ -221,12 +229,18 @@ export type BaseModelParam = z.infer<typeof zBaseModel>;
|
|||||||
export const zMainModel = z.object({
|
export const zMainModel = z.object({
|
||||||
model_name: z.string().min(1),
|
model_name: z.string().min(1),
|
||||||
base_model: zBaseModel,
|
base_model: zBaseModel,
|
||||||
|
model_type: zModelType,
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Type alias for model parameter, inferred from its zod schema
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
*/
|
*/
|
||||||
export type MainModelParam = z.infer<typeof zMainModel>;
|
export type MainModelParam = z.infer<typeof zMainModel>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type OnnxModelParam = z.infer<typeof zMainModel>;
|
||||||
/**
|
/**
|
||||||
* Validates/type-guards a value as a model parameter
|
* Validates/type-guards a value as a model parameter
|
||||||
*/
|
*/
|
||||||
|
@ -8,11 +8,12 @@ export const modelIdToMainModelParam = (
|
|||||||
mainModelId: string
|
mainModelId: string
|
||||||
): MainModelParam | undefined => {
|
): MainModelParam | undefined => {
|
||||||
const log = logger('models');
|
const log = logger('models');
|
||||||
const [base_model, _model_type, model_name] = mainModelId.split('/');
|
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||||
|
|
||||||
const result = zMainModel.safeParse({
|
const result = zMainModel.safeParse({
|
||||||
base_model,
|
base_model,
|
||||||
model_name,
|
model_name,
|
||||||
|
model_type,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
@ -8,7 +8,9 @@ import { useCallback, useState } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
MainModelConfigEntity,
|
MainModelConfigEntity,
|
||||||
|
OnnxModelConfigEntity,
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
|
useGetOnnxModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
LoRAModelConfigEntity,
|
LoRAModelConfigEntity,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
@ -20,9 +22,9 @@ type ModelListProps = {
|
|||||||
setSelectedModelId: (name: string | undefined) => void;
|
setSelectedModelId: (name: string | undefined) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type ModelFormat = 'images' | 'checkpoint' | 'diffusers';
|
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
|
||||||
|
|
||||||
type ModelType = 'main' | 'lora';
|
type ModelType = 'main' | 'lora' | 'onnx';
|
||||||
|
|
||||||
type CombinedModelFormat = ModelFormat | 'lora';
|
type CombinedModelFormat = ModelFormat | 'lora';
|
||||||
|
|
||||||
@ -61,6 +63,18 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||||
|
selectFromResult: ({ data }) => ({
|
||||||
|
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
|
||||||
|
selectFromResult: ({ data }) => ({
|
||||||
|
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||||
setNameFilter(e.target.value);
|
setNameFilter(e.target.value);
|
||||||
}, []);
|
}, []);
|
||||||
@ -85,10 +99,17 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
</IAIButton>
|
</IAIButton>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={() => setModelFormatFilter('checkpoint')}
|
onClick={() => setModelFormatFilter('onnx')}
|
||||||
isChecked={modelFormatFilter === 'checkpoint'}
|
isChecked={modelFormatFilter === 'onnx'}
|
||||||
>
|
>
|
||||||
{t('modelManager.checkpointModels')}
|
{t('modelManager.onnxModels')}
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton
|
||||||
|
size="sm"
|
||||||
|
onClick={() => setModelFormatFilter('olive')}
|
||||||
|
isChecked={modelFormatFilter === 'olive'}
|
||||||
|
>
|
||||||
|
{t('modelManager.oliveModels')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
size="sm"
|
size="sm"
|
||||||
@ -147,6 +168,42 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</StyledModelContainer>
|
</StyledModelContainer>
|
||||||
)}
|
)}
|
||||||
|
{['images', 'olive'].includes(modelFormatFilter) &&
|
||||||
|
filteredOliveModels.length > 0 && (
|
||||||
|
<StyledModelContainer>
|
||||||
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<Text variant="subtext" fontSize="sm">
|
||||||
|
Olives
|
||||||
|
</Text>
|
||||||
|
{filteredOliveModels.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
|
)}
|
||||||
|
{['images', 'onnx'].includes(modelFormatFilter) &&
|
||||||
|
filteredOnnxModels.length > 0 && (
|
||||||
|
<StyledModelContainer>
|
||||||
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<Text variant="subtext" fontSize="sm">
|
||||||
|
Onnx
|
||||||
|
</Text>
|
||||||
|
{filteredOnnxModels.map((model) => (
|
||||||
|
<ModelListItem
|
||||||
|
key={model.id}
|
||||||
|
model={model}
|
||||||
|
isSelected={selectedModelId === model.id}
|
||||||
|
setSelectedModelId={setSelectedModelId}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</StyledModelContainer>
|
||||||
|
)}
|
||||||
{['images', 'lora'].includes(modelFormatFilter) &&
|
{['images', 'lora'].includes(modelFormatFilter) &&
|
||||||
filteredLoraModels.length > 0 && (
|
filteredLoraModels.length > 0 && (
|
||||||
<StyledModelContainer>
|
<StyledModelContainer>
|
||||||
@ -173,7 +230,12 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
|
|
||||||
export default ModelList;
|
export default ModelList;
|
||||||
|
|
||||||
const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
|
const modelsFilter = <
|
||||||
|
T extends
|
||||||
|
| MainModelConfigEntity
|
||||||
|
| LoRAModelConfigEntity
|
||||||
|
| OnnxModelConfigEntity
|
||||||
|
>(
|
||||||
data: EntityState<T> | undefined,
|
data: EntityState<T> | undefined,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_format: ModelFormat | undefined,
|
model_format: ModelFormat | undefined,
|
||||||
|
@ -10,9 +10,11 @@ import {
|
|||||||
ImportModelConfig,
|
ImportModelConfig,
|
||||||
LoRAModelConfig,
|
LoRAModelConfig,
|
||||||
MainModelConfig,
|
MainModelConfig,
|
||||||
|
OnnxModelConfig,
|
||||||
MergeModelConfig,
|
MergeModelConfig,
|
||||||
TextualInversionModelConfig,
|
TextualInversionModelConfig,
|
||||||
VaeModelConfig,
|
VaeModelConfig,
|
||||||
|
ModelType,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
|
|
||||||
import queryString from 'query-string';
|
import queryString from 'query-string';
|
||||||
@ -27,6 +29,8 @@ export type MainModelConfigEntity =
|
|||||||
| DiffusersModelConfigEntity
|
| DiffusersModelConfigEntity
|
||||||
| CheckpointModelConfigEntity;
|
| CheckpointModelConfigEntity;
|
||||||
|
|
||||||
|
export type OnnxModelConfigEntity = OnnxModelConfig & { id: string };
|
||||||
|
|
||||||
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
|
||||||
|
|
||||||
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
||||||
@ -41,6 +45,7 @@ export type VaeModelConfigEntity = VaeModelConfig & { id: string };
|
|||||||
|
|
||||||
type AnyModelConfigEntity =
|
type AnyModelConfigEntity =
|
||||||
| MainModelConfigEntity
|
| MainModelConfigEntity
|
||||||
|
| OnnxModelConfigEntity
|
||||||
| LoRAModelConfigEntity
|
| LoRAModelConfigEntity
|
||||||
| ControlNetModelConfigEntity
|
| ControlNetModelConfigEntity
|
||||||
| TextualInversionModelConfigEntity
|
| TextualInversionModelConfigEntity
|
||||||
@ -66,6 +71,7 @@ type UpdateLoRAModelResponse = UpdateMainModelResponse;
|
|||||||
type DeleteMainModelArg = {
|
type DeleteMainModelArg = {
|
||||||
base_model: BaseModelType;
|
base_model: BaseModelType;
|
||||||
model_name: string;
|
model_name: string;
|
||||||
|
model_type: ModelType;
|
||||||
};
|
};
|
||||||
|
|
||||||
type DeleteMainModelResponse = void;
|
type DeleteMainModelResponse = void;
|
||||||
@ -119,6 +125,10 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
|||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const onnxModelsAdapter = createEntityAdapter<OnnxModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
|
});
|
||||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
@ -156,6 +166,49 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
|||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
|
getOnnxModels: build.query<
|
||||||
|
EntityState<OnnxModelConfigEntity>,
|
||||||
|
BaseModelType[]
|
||||||
|
>({
|
||||||
|
query: (base_models) => {
|
||||||
|
const params = {
|
||||||
|
model_type: 'onnx',
|
||||||
|
base_models,
|
||||||
|
};
|
||||||
|
|
||||||
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||||
|
return `models/?${query}`;
|
||||||
|
},
|
||||||
|
providesTags: (result, error, arg) => {
|
||||||
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ id: 'OnnxModel', type: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'OnnxModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (
|
||||||
|
response: { models: OnnxModelConfig[] },
|
||||||
|
meta,
|
||||||
|
arg
|
||||||
|
) => {
|
||||||
|
const entities = createModelEntities<OnnxModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return onnxModelsAdapter.setAll(
|
||||||
|
onnxModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
getMainModels: build.query<
|
getMainModels: build.query<
|
||||||
EntityState<MainModelConfigEntity>,
|
EntityState<MainModelConfigEntity>,
|
||||||
BaseModelType[]
|
BaseModelType[]
|
||||||
@ -248,9 +301,9 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
DeleteMainModelResponse,
|
DeleteMainModelResponse,
|
||||||
DeleteMainModelArg
|
DeleteMainModelArg
|
||||||
>({
|
>({
|
||||||
query: ({ base_model, model_name }) => {
|
query: ({ base_model, model_name, model_type }) => {
|
||||||
return {
|
return {
|
||||||
url: `models/${base_model}/main/${model_name}`,
|
url: `models/${base_model}/${model_type}/${model_name}`,
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
@ -494,6 +547,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
export const {
|
export const {
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
|
useGetOnnxModelsQuery,
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
|
360
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
360
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
@ -1381,7 +1381,7 @@ export type components = {
|
|||||||
* @description The nodes in this graph
|
* @description The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: {
|
nodes?: {
|
||||||
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Edges
|
* Edges
|
||||||
@ -1424,7 +1424,7 @@ export type components = {
|
|||||||
* @description The results of node executions
|
* @description The results of node executions
|
||||||
*/
|
*/
|
||||||
results: {
|
results: {
|
||||||
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Errors
|
* Errors
|
||||||
@ -2562,10 +2562,10 @@ export type components = {
|
|||||||
/**
|
/**
|
||||||
* Infill Method
|
* Infill Method
|
||||||
* @description The method used to infill empty regions (px)
|
* @description The method used to infill empty regions (px)
|
||||||
* @default patchmatch
|
* @default tile
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
infill_method?: "patchmatch" | "tile" | "solid";
|
infill_method?: "tile" | "solid";
|
||||||
/**
|
/**
|
||||||
* Inpaint Width
|
* Inpaint Width
|
||||||
* @description The width of the inpaint region (px)
|
* @description The width of the inpaint region (px)
|
||||||
@ -3173,6 +3173,8 @@ export type components = {
|
|||||||
model_name: string;
|
model_name: string;
|
||||||
/** @description Base model */
|
/** @description Base model */
|
||||||
base_model: components["schemas"]["BaseModelType"];
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
/** @description Model Type */
|
||||||
|
model_type: components["schemas"]["ModelType"];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* MainModelLoaderInvocation
|
* MainModelLoaderInvocation
|
||||||
@ -3618,7 +3620,7 @@ export type components = {
|
|||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
ModelType: "main" | "vae" | "lora" | "controlnet" | "embedding";
|
ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding";
|
||||||
/**
|
/**
|
||||||
* ModelVariantType
|
* ModelVariantType
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -3628,7 +3630,7 @@ export type components = {
|
|||||||
/** ModelsList */
|
/** ModelsList */
|
||||||
ModelsList: {
|
ModelsList: {
|
||||||
/** Models */
|
/** Models */
|
||||||
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
|
models: (components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* MultiplyInvocation
|
* MultiplyInvocation
|
||||||
@ -3778,6 +3780,261 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
image_resolution?: number;
|
image_resolution?: number;
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* ONNXLatentsToImageInvocation
|
||||||
|
* @description Generates an image from latents.
|
||||||
|
*/
|
||||||
|
ONNXLatentsToImageInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default l2i_onnx
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "l2i_onnx";
|
||||||
|
/**
|
||||||
|
* Latents
|
||||||
|
* @description The latents to generate an image from
|
||||||
|
*/
|
||||||
|
latents?: components["schemas"]["LatentsField"];
|
||||||
|
/**
|
||||||
|
* Vae
|
||||||
|
* @description Vae submodel
|
||||||
|
*/
|
||||||
|
vae?: components["schemas"]["VaeField"];
|
||||||
|
/**
|
||||||
|
* Metadata
|
||||||
|
* @description Optional core metadata to be written to the image
|
||||||
|
*/
|
||||||
|
metadata?: components["schemas"]["CoreMetadata"];
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* ONNXModelLoaderOutput
|
||||||
|
* @description Model loader output
|
||||||
|
*/
|
||||||
|
ONNXModelLoaderOutput: {
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default model_loader_output_onnx
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "model_loader_output_onnx";
|
||||||
|
/**
|
||||||
|
* Unet
|
||||||
|
* @description UNet submodel
|
||||||
|
*/
|
||||||
|
unet?: components["schemas"]["UNetField"];
|
||||||
|
/**
|
||||||
|
* Clip
|
||||||
|
* @description Tokenizer and text_encoder submodels
|
||||||
|
*/
|
||||||
|
clip?: components["schemas"]["ClipField"];
|
||||||
|
/**
|
||||||
|
* Vae Decoder
|
||||||
|
* @description Vae submodel
|
||||||
|
*/
|
||||||
|
vae_decoder?: components["schemas"]["VaeField"];
|
||||||
|
/**
|
||||||
|
* Vae Encoder
|
||||||
|
* @description Vae submodel
|
||||||
|
*/
|
||||||
|
vae_encoder?: components["schemas"]["VaeField"];
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* ONNXPromptInvocation
|
||||||
|
* @description A node to process inputs and produce outputs.
|
||||||
|
* May use dependency injection in __init__ to receive providers.
|
||||||
|
*/
|
||||||
|
ONNXPromptInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default prompt_onnx
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "prompt_onnx";
|
||||||
|
/**
|
||||||
|
* Prompt
|
||||||
|
* @description Prompt
|
||||||
|
* @default
|
||||||
|
*/
|
||||||
|
prompt?: string;
|
||||||
|
/**
|
||||||
|
* Clip
|
||||||
|
* @description Clip to use
|
||||||
|
*/
|
||||||
|
clip?: components["schemas"]["ClipField"];
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* ONNXSD1ModelLoaderInvocation
|
||||||
|
* @description Loading submodels of selected model.
|
||||||
|
*/
|
||||||
|
ONNXSD1ModelLoaderInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default sd1_model_loader_onnx
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "sd1_model_loader_onnx";
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description Model to load
|
||||||
|
* @default
|
||||||
|
*/
|
||||||
|
model_name?: string;
|
||||||
|
};
|
||||||
|
/** ONNXStableDiffusion1ModelConfig */
|
||||||
|
ONNXStableDiffusion1ModelConfig: {
|
||||||
|
/** Model Name */
|
||||||
|
model_name: string;
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
/**
|
||||||
|
* Model Type
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
model_type: "onnx";
|
||||||
|
/** Path */
|
||||||
|
path: string;
|
||||||
|
/** Description */
|
||||||
|
description?: string;
|
||||||
|
/**
|
||||||
|
* Model Format
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
model_format: "onnx";
|
||||||
|
error?: components["schemas"]["ModelError"];
|
||||||
|
variant: components["schemas"]["ModelVariantType"];
|
||||||
|
};
|
||||||
|
/** ONNXStableDiffusion2ModelConfig */
|
||||||
|
ONNXStableDiffusion2ModelConfig: {
|
||||||
|
/** Model Name */
|
||||||
|
model_name: string;
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
/**
|
||||||
|
* Model Type
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
model_type: "onnx";
|
||||||
|
/** Path */
|
||||||
|
path: string;
|
||||||
|
/** Description */
|
||||||
|
description?: string;
|
||||||
|
/**
|
||||||
|
* Model Format
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
model_format: "onnx";
|
||||||
|
error?: components["schemas"]["ModelError"];
|
||||||
|
variant: components["schemas"]["ModelVariantType"];
|
||||||
|
prediction_type: components["schemas"]["SchedulerPredictionType"];
|
||||||
|
/** Upcast Attention */
|
||||||
|
upcast_attention: boolean;
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* ONNXTextToLatentsInvocation
|
||||||
|
* @description Generates latents from conditionings.
|
||||||
|
*/
|
||||||
|
ONNXTextToLatentsInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default t2l_onnx
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "t2l_onnx";
|
||||||
|
/**
|
||||||
|
* Positive Conditioning
|
||||||
|
* @description Positive conditioning for generation
|
||||||
|
*/
|
||||||
|
positive_conditioning?: components["schemas"]["ConditioningField"];
|
||||||
|
/**
|
||||||
|
* Negative Conditioning
|
||||||
|
* @description Negative conditioning for generation
|
||||||
|
*/
|
||||||
|
negative_conditioning?: components["schemas"]["ConditioningField"];
|
||||||
|
/**
|
||||||
|
* Noise
|
||||||
|
* @description The noise to use
|
||||||
|
*/
|
||||||
|
noise?: components["schemas"]["LatentsField"];
|
||||||
|
/**
|
||||||
|
* Steps
|
||||||
|
* @description The number of steps to use to generate the image
|
||||||
|
* @default 10
|
||||||
|
*/
|
||||||
|
steps?: number;
|
||||||
|
/**
|
||||||
|
* Cfg Scale
|
||||||
|
* @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
|
||||||
|
* @default 7.5
|
||||||
|
*/
|
||||||
|
cfg_scale?: number | (number)[];
|
||||||
|
/**
|
||||||
|
* Scheduler
|
||||||
|
* @description The scheduler to use
|
||||||
|
* @default euler
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
|
||||||
|
/**
|
||||||
|
* Precision
|
||||||
|
* @description The precision to use when generating latents
|
||||||
|
* @default tensor(float16)
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
precision?: "tensor(bool)" | "tensor(int8)" | "tensor(uint8)" | "tensor(int16)" | "tensor(uint16)" | "tensor(int32)" | "tensor(uint32)" | "tensor(int64)" | "tensor(uint64)" | "tensor(float16)" | "tensor(float)" | "tensor(double)";
|
||||||
|
/**
|
||||||
|
* Unet
|
||||||
|
* @description UNet submodel
|
||||||
|
*/
|
||||||
|
unet?: components["schemas"]["UNetField"];
|
||||||
|
/**
|
||||||
|
* Control
|
||||||
|
* @description The control to use
|
||||||
|
*/
|
||||||
|
control?: components["schemas"]["ControlField"] | (components["schemas"]["ControlField"])[];
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* OffsetPaginatedResults[BoardDTO]
|
* OffsetPaginatedResults[BoardDTO]
|
||||||
* @description Offset-paginated results
|
* @description Offset-paginated results
|
||||||
@ -3830,6 +4087,49 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
total: number;
|
total: number;
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* OnnxModelField
|
||||||
|
* @description Onnx model field
|
||||||
|
*/
|
||||||
|
OnnxModelField: {
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description Name of the model
|
||||||
|
*/
|
||||||
|
model_name: string;
|
||||||
|
/** @description Base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
/** @description Model Type */
|
||||||
|
model_type: components["schemas"]["ModelType"];
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* OnnxModelLoaderInvocation
|
||||||
|
* @description Loads a main model, outputting its submodels.
|
||||||
|
*/
|
||||||
|
OnnxModelLoaderInvocation: {
|
||||||
|
/**
|
||||||
|
* Id
|
||||||
|
* @description The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
/**
|
||||||
|
* Is Intermediate
|
||||||
|
* @description Whether or not this node is an intermediate node.
|
||||||
|
* @default false
|
||||||
|
*/
|
||||||
|
is_intermediate?: boolean;
|
||||||
|
/**
|
||||||
|
* Type
|
||||||
|
* @default onnx_model_loader
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type?: "onnx_model_loader";
|
||||||
|
/**
|
||||||
|
* Model
|
||||||
|
* @description The model to load
|
||||||
|
*/
|
||||||
|
model: components["schemas"]["OnnxModelField"];
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* OpenposeImageProcessorInvocation
|
* OpenposeImageProcessorInvocation
|
||||||
* @description Applies Openpose processing to image
|
* @description Applies Openpose processing to image
|
||||||
@ -4963,6 +5263,12 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
antialias?: boolean;
|
antialias?: boolean;
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* SchedulerPredictionType
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
SchedulerPredictionType: "epsilon" | "v_prediction" | "sample";
|
||||||
/**
|
/**
|
||||||
* SegmentAnythingProcessorInvocation
|
* SegmentAnythingProcessorInvocation
|
||||||
* @description Applies segment anything processing to image
|
* @description Applies segment anything processing to image
|
||||||
@ -5273,7 +5579,7 @@ export type components = {
|
|||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "scheduler" | "safety_checker";
|
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker";
|
||||||
/**
|
/**
|
||||||
* SubtractInvocation
|
* SubtractInvocation
|
||||||
* @description Subtracts two numbers
|
* @description Subtracts two numbers
|
||||||
@ -5586,29 +5892,35 @@ export type components = {
|
|||||||
image?: components["schemas"]["ImageField"];
|
image?: components["schemas"]["ImageField"];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* StableDiffusion2ModelFormat
|
* StableDiffusionOnnxModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||||
/**
|
/**
|
||||||
* StableDiffusionXLModelFormat
|
* StableDiffusionXLModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||||
/**
|
|
||||||
* StableDiffusion1ModelFormat
|
|
||||||
* @description An enumeration.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
|
||||||
/**
|
/**
|
||||||
* ControlNetModelFormat
|
* ControlNetModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
ControlNetModelFormat: "checkpoint" | "diffusers";
|
ControlNetModelFormat: "checkpoint" | "diffusers";
|
||||||
|
/**
|
||||||
|
* StableDiffusion2ModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||||
|
/**
|
||||||
|
* StableDiffusion1ModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||||
};
|
};
|
||||||
responses: never;
|
responses: never;
|
||||||
parameters: never;
|
parameters: never;
|
||||||
@ -5719,7 +6031,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -5756,7 +6068,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
@ -6020,14 +6332,14 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
/** @description The model was updated successfully */
|
/** @description The model was updated successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -6058,7 +6370,7 @@ export type operations = {
|
|||||||
/** @description The model imported successfully */
|
/** @description The model imported successfully */
|
||||||
201: {
|
201: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description The model could not be found */
|
/** @description The model could not be found */
|
||||||
@ -6084,14 +6396,14 @@ export type operations = {
|
|||||||
add_model: {
|
add_model: {
|
||||||
requestBody: {
|
requestBody: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
responses: {
|
responses: {
|
||||||
/** @description The model added successfully */
|
/** @description The model added successfully */
|
||||||
201: {
|
201: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description The model could not be found */
|
/** @description The model could not be found */
|
||||||
@ -6131,7 +6443,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -6220,7 +6532,7 @@ export type operations = {
|
|||||||
/** @description Model converted successfully */
|
/** @description Model converted successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
"application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Incompatible models */
|
/** @description Incompatible models */
|
||||||
|
@ -32,6 +32,7 @@ export type ModelType = components['schemas']['ModelType'];
|
|||||||
export type SubModelType = components['schemas']['SubModelType'];
|
export type SubModelType = components['schemas']['SubModelType'];
|
||||||
export type BaseModelType = components['schemas']['BaseModelType'];
|
export type BaseModelType = components['schemas']['BaseModelType'];
|
||||||
export type MainModelField = components['schemas']['MainModelField'];
|
export type MainModelField = components['schemas']['MainModelField'];
|
||||||
|
export type OnnxModelField = components['schemas']['OnnxModelField'];
|
||||||
export type VAEModelField = components['schemas']['VAEModelField'];
|
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||||
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
||||||
export type ControlNetModelField =
|
export type ControlNetModelField =
|
||||||
@ -58,6 +59,8 @@ export type DiffusersModelConfig =
|
|||||||
export type CheckpointModelConfig =
|
export type CheckpointModelConfig =
|
||||||
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
||||||
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
|
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
|
||||||
|
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
|
||||||
|
export type OnnxModelConfig = components['schemas']['ONNXStableDiffusion1ModelConfig']
|
||||||
| components['schemas']['StableDiffusionXLModelCheckpointConfig'];
|
| components['schemas']['StableDiffusionXLModelCheckpointConfig'];
|
||||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||||
export type AnyModelConfig =
|
export type AnyModelConfig =
|
||||||
@ -65,7 +68,8 @@ export type AnyModelConfig =
|
|||||||
| VaeModelConfig
|
| VaeModelConfig
|
||||||
| ControlNetModelConfig
|
| ControlNetModelConfig
|
||||||
| TextualInversionModelConfig
|
| TextualInversionModelConfig
|
||||||
| MainModelConfig;
|
| MainModelConfig
|
||||||
|
| OnnxModelConfig;
|
||||||
|
|
||||||
export type MergeModelConfig = components['schemas']['Body_merge_models'];
|
export type MergeModelConfig = components['schemas']['Body_merge_models'];
|
||||||
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
|
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
|
||||||
@ -127,6 +131,9 @@ export type ImageCollectionInvocation = TypeReq<
|
|||||||
export type MainModelLoaderInvocation = TypeReq<
|
export type MainModelLoaderInvocation = TypeReq<
|
||||||
components['schemas']['MainModelLoaderInvocation']
|
components['schemas']['MainModelLoaderInvocation']
|
||||||
>;
|
>;
|
||||||
|
export type OnnxModelLoaderInvocation = TypeReq<
|
||||||
|
components['schemas']['OnnxModelLoaderInvocation']
|
||||||
|
>;
|
||||||
export type LoraLoaderInvocation = TypeReq<
|
export type LoraLoaderInvocation = TypeReq<
|
||||||
components['schemas']['LoraLoaderInvocation']
|
components['schemas']['LoraLoaderInvocation']
|
||||||
>;
|
>;
|
||||||
|
@ -61,6 +61,7 @@ dependencies = [
|
|||||||
"numpy",
|
"numpy",
|
||||||
"npyscreen",
|
"npyscreen",
|
||||||
"omegaconf",
|
"omegaconf",
|
||||||
|
"onnx",
|
||||||
"opencv-python",
|
"opencv-python",
|
||||||
"pydantic==1.*",
|
"pydantic==1.*",
|
||||||
"picklescan",
|
"picklescan",
|
||||||
@ -103,6 +104,15 @@ dependencies = [
|
|||||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||||
"triton; sys_platform=='linux'",
|
"triton; sys_platform=='linux'",
|
||||||
]
|
]
|
||||||
|
"onnx" = [
|
||||||
|
"onnxruntime",
|
||||||
|
]
|
||||||
|
"onnx-cuda" = [
|
||||||
|
"onnxruntime-gpu",
|
||||||
|
]
|
||||||
|
"onnx-directml" = [
|
||||||
|
"onnxruntime-directml",
|
||||||
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
||||||
@ -180,4 +190,4 @@ output = "coverage/index.xml"
|
|||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
Loading…
Reference in New Issue
Block a user