Merge branch 'main' into minimapcontrol

This commit is contained in:
blessedcoolant 2023-07-18 16:25:49 +12:00 committed by GitHub
commit 6e40b543cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 2735 additions and 79 deletions

View File

@ -125,7 +125,7 @@ To make a workspace:
Should look something like this:
```json
```jsonc
{
// I like to have all InvokeAI-related folders in my workspace
"folders": [
@ -164,7 +164,7 @@ to set up your `launch.json` and try it out.
Now we can create the InvokeAI debugging configs:
```json
```jsonc
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.

View File

@ -14,6 +14,7 @@ from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
ModelNotFoundException,
InvalidModelException,
)
from invokeai.backend.model_management import MergeInterpolationMethod
@ -36,11 +37,16 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
if base_models and len(base_models)>0:
models_raw = list()
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@ -123,6 +129,7 @@ async def update_model(
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
415: {"description" : "Unrecognized file/folder format"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
@ -149,7 +156,7 @@ async def import_model(
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
raise HTTPException(status_code=415)
logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@ -162,6 +169,9 @@ async def import_model(
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))

View File

@ -1,8 +1,8 @@
from typing import Literal, Optional, Union, List
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re
import torch
from compel import Compel
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
@ -14,6 +14,7 @@ from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import ClipField
from dataclasses import dataclass
class ConditioningField(BaseModel):
@ -23,6 +24,34 @@ class ConditioningField(BaseModel):
class Config:
schema_extra = {"required": ["conditioning_name"]}
@dataclass
class BasicConditioningInfo:
#type: Literal["basic_conditioning"] = "basic_conditioning"
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
ConditioningInfoType = Annotated[
Union[BasicConditioningInfo, SDXLConditioningInfo],
Field(discriminator="type")
]
@dataclass
class ConditioningFieldData:
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
#unconditioned: Optional[torch.Tensor]
#class ConditioningAlgo(str, Enum):
# Compose = "compose"
# ComposeEx = "compose_ex"
# PerpNeg = "perp_neg"
class CompelOutput(BaseInvocationOutput):
"""Compel parser output"""
@ -119,10 +148,17 @@ class CompelInvocation(BaseInvocation):
cross_attention_control_args=options.get(
"cross_attention_control", None),)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
conditioning_data = ConditioningFieldData(
conditionings=[
BasicConditioningInfo(
embeds=c,
extra_conditioning=ec,
)
]
)
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
@ -130,6 +166,389 @@ class CompelInvocation(BaseInvocation):
),
)
class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
if get_pooled:
c_pooled = prompt_embeds[0]
else:
c_pooled = None
c = prompt_embeds.hidden_states[-2]
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
return c, c_pooled, None
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=True,
)
conjunction = Compel.parse_prompt_string(prompt)
if context.services.configuration.log_tokenization:
# TODO: better logging for and syntax
for prompt_obj in conjunction.prompts:
log_tokenization_for_prompt_object(prompt_obj, tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
if get_pooled:
c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
else:
c_pooled = None
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
return c, c_pooled, ec
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
prompt: str = Field(default="", description="Prompt")
style: str = Field(default="", description="Style prompt")
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
target_width: int = Field(1024, description="")
target_height: int = Field(1024, description="")
clip1: ClipField = Field(None, description="Clip to use")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip1, self.prompt, False)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=torch.cat([c1, c2], dim=-1),
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec1,
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ?
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
aesthetic_score: float = Field(6.0, description="")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=c2,
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec2, # or None
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Pass unmodified prompt to conditioning without compel processing."""
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
prompt: str = Field(default="", description="Prompt")
style: str = Field(default="", description="Style prompt")
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
target_width: int = Field(1024, description="")
target_height: int = Field(1024, description="")
clip1: ClipField = Field(None, description="Clip to use")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip1, self.prompt, False)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
else:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=torch.cat([c1, c2], dim=-1),
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec1,
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ?
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
crop_left: int = Field(0, description="")
aesthetic_score: float = Field(6.0, description="")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=c2,
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec2, # or None
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"

View File

@ -31,6 +31,13 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
@ -161,12 +168,12 @@ class TextToLatentsInvocation(BaseInvocation):
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name
)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -476,8 +483,8 @@ class LatentsToImageInvocation(BaseInvocation):
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(False, description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# Schema customisation
class Config(InvocationConfig):
@ -496,6 +503,31 @@ class LatentsToImageInvocation(BaseInvocation):
)
with vae_info as vae:
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.services.configuration.tiled_decode:
vae.enable_tiling()
else:
@ -619,6 +651,8 @@ class ImageToLatentsInvocation(BaseInvocation):
tiled: bool = Field(
default=False,
description="Encode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(False, description="Decode in full precision")
# Schema customisation
class Config(InvocationConfig):
@ -645,6 +679,32 @@ class ImageToLatentsInvocation(BaseInvocation):
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
with vae_info as vae:
orig_dtype = vae.dtype
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
#else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
#latents = latents.half()
if self.tiled:
vae.enable_tiling()
else:
@ -659,6 +719,7 @@ class ImageToLatentsInvocation(BaseInvocation):
) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents
latents = latents.to(dtype=orig_dtype)
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)

View File

@ -33,7 +33,6 @@ class ClipField(BaseModel):
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
@ -50,7 +49,6 @@ class ModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class MainModelField(BaseModel):
"""Main model field"""
@ -64,7 +62,6 @@ class LoRAModelField(BaseModel):
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@ -157,6 +154,22 @@ class MainModelLoaderInvocation(BaseInvocation):
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
@ -167,7 +180,7 @@ class MainModelLoaderInvocation(BaseInvocation):
),
)
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""

View File

@ -0,0 +1,658 @@
import torch
import inspect
from tqdm import tqdm
from typing import List, Literal, Optional, Union
from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
# fmt: off
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
# fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
#fmt: on
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
model: MainModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Model Loader",
"tags": ["model", "loader", "sdxl"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return SDXLModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
model: MainModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_sdxl"] = "t2l_sdxl"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.noise.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
prompt_embeds = positive_cond_data.conditionings[0].embeds
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
latents = latents * scheduler.init_noise_sigma
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
generator=torch.Generator(device=unet.device).manual_seed(0),
)
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
# apply denoising_end
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
num_inference_steps = num_inference_steps - skipped_final_steps
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
if not context.services.configuration.sequential_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=self.steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_uncond
#del noise_pred_text
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=self.steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# predict the noise residual
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
noise_pred_uncond = unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_text = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_text
#del noise_pred_uncond
#import gc
#gc.collect()
#torch.cuda.empty_cache()
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
#del noise_pred
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)
class SDXLLatentsToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["l2l_sdxl"] = "l2l_sdxl"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
latents: Optional[LatentsField] = Field(description="Initial latents")
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
prompt_embeds = positive_cond_data.conditionings[0].embeds
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order:]
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
# apply scheduler extra args
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
generator=torch.Generator(device=unet.device).manual_seed(0),
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
# apply denoising_end
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
num_inference_steps = num_inference_steps - skipped_final_steps
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
if not context.services.configuration.sequential_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_uncond
#del noise_pred_text
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# predict the noise residual
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_uncond = unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_text = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_text
#del noise_pred_uncond
#import gc
#gc.collect()
#torch.cuda.empty_cache()
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
#del noise_pred
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@ -212,7 +212,7 @@ class ModelInstall(object):
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan
elif path.is_dir():

View File

@ -12,6 +12,7 @@ from picklescan.scanner import scan_file_path
from .models import (
BaseModelType, ModelType, ModelVariantType,
SchedulerPredictionType, SilenceWarnings,
InvalidModelException
)
from .models.base import read_checkpoint_meta
@ -38,6 +39,8 @@ class ModelProbe(object):
CLASS2TYPE = {
'StableDiffusionPipeline' : ModelType.Main,
'StableDiffusionXLPipeline' : ModelType.Main,
'StableDiffusionXLImg2ImgPipeline' : ModelType.Main,
'AutoencoderKL' : ModelType.Vae,
'ControlNetModel' : ModelType.ControlNet,
}
@ -59,7 +62,7 @@ class ModelProbe(object):
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise ValueError("model parameter {model} is neither a Path, nor a model")
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(cls,
@ -99,9 +102,10 @@ class ModelProbe(object):
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction),
format = format,
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction \
) else 512,
image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \
768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction ) else \
512
)
except Exception:
raise
@ -138,7 +142,7 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise ValueError(f"Unable to determine model type for {model_path}")
raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@ -168,7 +172,7 @@ class ModelProbe(object):
return type
# give up
raise ValueError(f"Unable to determine model type for {folder_path}")
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
@ -237,7 +241,7 @@ class CheckpointProbeBase(ProbeBase):
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
@ -248,7 +252,10 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise ValueError("Cannot determine base type")
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()
@ -329,7 +336,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
########################################################
# classes for probing folders
@ -360,8 +367,12 @@ class PipelineFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf['cross_attention_dim'] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf['cross_attention_dim'] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
if self.model:
@ -418,7 +429,7 @@ class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'config.json'
if not config_file.exists():
raise ValueError(f"Cannot determine base type for {self.folder_path}")
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file,'r') as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
@ -435,7 +446,7 @@ class LoRAFolderProbe(FolderProbeBase):
model_file = base_file
break
if not model_file:
raise ValueError('Unknown LoRA format encountered')
raise InvalidModelException('Unknown LoRA format encountered')
return LoRACheckpointProbe(model_file,None).get_base_type()
############## register probe classes ######

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .sdxl import StableDiffusionXLModel
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
@ -24,6 +25,22 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
ModelType.Vae: VaeModel,
# will not work until support written
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,

View File

@ -24,6 +24,8 @@ class ModelNotFoundException(Exception):
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
#Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
@ -36,7 +38,9 @@ class ModelType(str, Enum):
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"

View File

@ -0,0 +1,114 @@
import os
import json
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
from .base import (
ModelConfigBase,
BaseModelType,
ModelType,
ModelVariantType,
DiffusersModel,
read_checkpoint_meta,
classproperty,
)
from omegaconf import OmegaConf
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
assert model_type == ModelType.Main
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusionXL,
model_type=ModelType.Main,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusionXLModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# TO DO: implement picking
pass
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusionXLModelFormat.Diffusers
else:
return StableDiffusionXLModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if isinstance(config, cls.CheckpointConfig):
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
else:
return model_path

View File

@ -5,14 +5,11 @@ from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
@ -248,6 +245,12 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
# note that these .yaml files don't yet exist!
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "xl-inference-v.yaml",
ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
ModelVariantType.Depth: "xl-midas-inference.yaml",
}
}
@ -263,6 +266,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
# TODO: rework
# Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],

View File

@ -1,4 +1,10 @@
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-xdl-base:
description: Stable Diffusion XL base model - NOT YET RELEASED!! (70 GB)
repo_id: stabilityai/stable-diffusion-xl-base
sd-1/main/stable-diffusion-xdl-refiner:
description: Stable Diffusion XL refiner model - NOT YET RELEASED!! (60 GB)
repo_id: stabilityai/stable-diffusion-xl-refiner
sd-1/main/stable-diffusion-v1-5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -14,6 +14,14 @@ export const clipSkipMap = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
'sdxl-refiner': {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
};
export default function ParamClipSkip() {

View File

@ -1,4 +1,6 @@
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
};

View File

@ -126,7 +126,7 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success;
const zBaseModel = z.enum(['sd-1', 'sd-2']);
const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
export type BaseModelParam = z.infer<typeof zBaseModel>;

View File

@ -19,18 +19,16 @@ export default function ScanAdvancedAddModels() {
const [advancedAddMode, setAdvancedAddMode] =
useState<ManualAddMode>('diffusers');
const [isCheckpoint, setIsCheckpoint] = useState(
advancedAddScanModel &&
['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) =>
advancedAddScanModel.endsWith(ext)
)
);
const [isCheckpoint, setIsCheckpoint] = useState<boolean>(true);
useEffect(() => {
isCheckpoint
advancedAddScanModel &&
['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) =>
advancedAddScanModel.endsWith(ext)
)
? setAdvancedAddMode('checkpoint')
: setAdvancedAddMode('diffusers');
}, [setAdvancedAddMode, isCheckpoint]);
}, [advancedAddScanModel, setAdvancedAddMode, isCheckpoint]);
const dispatch = useAppDispatch();

View File

@ -58,10 +58,10 @@ export default function MergeModelsPanel() {
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[0]
);
const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
@ -89,9 +89,9 @@ export default function MergeModelsPanel() {
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
);
const modelThreeList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelTwo);
const handleBaseModelChange = (v: string) => {
setBaseModel(v as BaseModelType);

View File

@ -1,4 +1,4 @@
import { Divider, Flex, Text } from '@chakra-ui/react';
import { Badge, Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -115,7 +115,20 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<ModelConvert model={model} />
{!['sdxl', 'sdxl-refiner'].includes(model.base_model) ? (
<ModelConvert model={model} />
) : (
<Badge
sx={{
p: 2,
borderRadius: 4,
bg: 'error.200',
_dark: { bg: 'error.400' },
}}
>
Conversion Not Supported
</Badge>
)}
</Flex>
<Divider />

View File

@ -8,6 +8,8 @@ import { useTranslation } from 'react-i18next';
const baseModelSelectData: IAISelectDataType[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
type BaseModelSelectProps = Omit<IAISelectProps, 'data'>;

View File

@ -316,7 +316,7 @@ export type components = {
* @description An enumeration.
* @enum {string}
*/
BaseModelType: "sd-1" | "sd-2";
BaseModelType: "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner";
/** BoardChanges */
BoardChanges: {
/**
@ -1219,7 +1219,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
};
/**
* Edges
@ -1262,7 +1262,7 @@ export type components = {
* @description The results of node executions
*/
results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
};
/**
* Errors
@ -2060,6 +2060,12 @@ export type components = {
* @default false
*/
tiled?: boolean;
/**
* Fp32
* @description Decode in full precision
* @default false
*/
fp32?: boolean;
};
/**
* ImageUrlsDTO
@ -2515,6 +2521,12 @@ export type components = {
* @default false
*/
tiled?: boolean;
/**
* Fp32
* @description Decode in full precision
* @default false
*/
fp32?: boolean;
/**
* Metadata
* @description Optional core metadata to be written to the image
@ -3335,7 +3347,7 @@ export type components = {
/** ModelsList */
ModelsList: {
/** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[];
};
/**
* MultiplyInvocation
@ -3763,6 +3775,56 @@ export type components = {
*/
prompt: string;
};
/**
* PromptsFromFileInvocation
* @description Loads prompts from a text file
*/
PromptsFromFileInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default prompt_from_file
* @enum {string}
*/
type?: "prompt_from_file";
/**
* File Path
* @description Path to prompt text file
*/
file_path: string;
/**
* Pre Prompt
* @description String to prepend to each prompt
*/
pre_prompt?: string;
/**
* Post Prompt
* @description String to append to each prompt
*/
post_prompt?: string;
/**
* Start Line
* @description Line in the file to start start from
* @default 1
*/
start_line?: number;
/**
* Max Prompts
* @description Max lines to read from file (0=all)
* @default 1
*/
max_prompts?: number;
};
/**
* RandomIntInvocation
* @description Outputs a single random integer.
@ -4024,6 +4086,533 @@ export type components = {
* @enum {string}
*/
ResourceOrigin: "internal" | "external";
/**
* SDXLCompelPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLCompelPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_compel_prompt
* @enum {string}
*/
type?: "sdxl_compel_prompt";
/**
* Prompt
* @description Prompt
* @default
*/
prompt?: string;
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Target Width
* @default 1024
*/
target_width?: number;
/**
* Target Height
* @default 1024
*/
target_height?: number;
/**
* Clip1
* @description Clip to use
*/
clip1?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLLatentsToLatentsInvocation
* @description Generates latents from conditionings.
*/
SDXLLatentsToLatentsInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default l2l_sdxl
* @enum {string}
*/
type?: "l2l_sdxl";
/**
* Positive Conditioning
* @description Positive conditioning for generation
*/
positive_conditioning?: components["schemas"]["ConditioningField"];
/**
* Negative Conditioning
* @description Negative conditioning for generation
*/
negative_conditioning?: components["schemas"]["ConditioningField"];
/**
* Noise
* @description The noise to use
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
* @description The number of steps to use to generate the image
* @default 10
*/
steps?: number;
/**
* Cfg Scale
* @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Scheduler
* @description The scheduler to use
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Latents
* @description Initial latents
*/
latents?: components["schemas"]["LatentsField"];
/**
* Denoising Start
* @default 0
*/
denoising_start?: number;
/**
* Denoising End
* @default 1
*/
denoising_end?: number;
};
/**
* SDXLModelLoaderInvocation
* @description Loads an sdxl base model, outputting its submodels.
*/
SDXLModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_model_loader
* @enum {string}
*/
type?: "sdxl_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["MainModelField"];
};
/**
* SDXLModelLoaderOutput
* @description SDXL base model loader output
*/
SDXLModelLoaderOutput: {
/**
* Type
* @default sdxl_model_loader_output
* @enum {string}
*/
type?: "sdxl_model_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip
* @description Tokenizer and text_encoder submodels
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Tokenizer and text_encoder submodels
*/
clip2?: components["schemas"]["ClipField"];
/**
* Vae
* @description Vae submodel
*/
vae?: components["schemas"]["VaeField"];
};
/**
* SDXLRawPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRawPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_raw_prompt
* @enum {string}
*/
type?: "sdxl_raw_prompt";
/**
* Prompt
* @description Prompt
* @default
*/
prompt?: string;
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Target Width
* @default 1024
*/
target_width?: number;
/**
* Target Height
* @default 1024
*/
target_height?: number;
/**
* Clip1
* @description Clip to use
*/
clip1?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLRefinerCompelPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRefinerCompelPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_compel_prompt
* @enum {string}
*/
type?: "sdxl_refiner_compel_prompt";
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Aesthetic Score
* @default 6
*/
aesthetic_score?: number;
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLRefinerModelLoaderInvocation
* @description Loads an sdxl refiner model, outputting its submodels.
*/
SDXLRefinerModelLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_model_loader
* @enum {string}
*/
type?: "sdxl_refiner_model_loader";
/**
* Model
* @description The model to load
*/
model: components["schemas"]["MainModelField"];
};
/**
* SDXLRefinerModelLoaderOutput
* @description SDXL refiner model loader output
*/
SDXLRefinerModelLoaderOutput: {
/**
* Type
* @default sdxl_refiner_model_loader_output
* @enum {string}
*/
type?: "sdxl_refiner_model_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip2
* @description Tokenizer and text_encoder submodels
*/
clip2?: components["schemas"]["ClipField"];
/**
* Vae
* @description Vae submodel
*/
vae?: components["schemas"]["VaeField"];
};
/**
* SDXLRefinerRawPromptInvocation
* @description Parse prompt using compel package to conditioning.
*/
SDXLRefinerRawPromptInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_refiner_raw_prompt
* @enum {string}
*/
type?: "sdxl_refiner_raw_prompt";
/**
* Style
* @description Style prompt
* @default
*/
style?: string;
/**
* Original Width
* @default 1024
*/
original_width?: number;
/**
* Original Height
* @default 1024
*/
original_height?: number;
/**
* Crop Top
* @default 0
*/
crop_top?: number;
/**
* Crop Left
* @default 0
*/
crop_left?: number;
/**
* Aesthetic Score
* @default 6
*/
aesthetic_score?: number;
/**
* Clip2
* @description Clip to use
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLTextToLatentsInvocation
* @description Generates latents from conditionings.
*/
SDXLTextToLatentsInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default t2l_sdxl
* @enum {string}
*/
type?: "t2l_sdxl";
/**
* Positive Conditioning
* @description Positive conditioning for generation
*/
positive_conditioning?: components["schemas"]["ConditioningField"];
/**
* Negative Conditioning
* @description Negative conditioning for generation
*/
negative_conditioning?: components["schemas"]["ConditioningField"];
/**
* Noise
* @description The noise to use
*/
noise?: components["schemas"]["LatentsField"];
/**
* Steps
* @description The number of steps to use to generate the image
* @default 10
*/
steps?: number;
/**
* Cfg Scale
* @description The Classifier-Free Guidance, higher values may result in a result closer to the prompt
* @default 7.5
*/
cfg_scale?: number | (number)[];
/**
* Scheduler
* @description The scheduler to use
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Denoising End
* @default 1
*/
denoising_end?: number;
};
/**
* ScaleLatentsInvocation
* @description Scales latents by a given factor.
@ -4226,6 +4815,56 @@ export type components = {
vae?: string;
variant: components["schemas"]["ModelVariantType"];
};
/** StableDiffusionXLModelCheckpointConfig */
StableDiffusionXLModelCheckpointConfig: {
/** Model Name */
model_name: string;
base_model: components["schemas"]["BaseModelType"];
/**
* Model Type
* @enum {string}
*/
model_type: "main";
/** Path */
path: string;
/** Description */
description?: string;
/**
* Model Format
* @enum {string}
*/
model_format: "checkpoint";
error?: components["schemas"]["ModelError"];
/** Vae */
vae?: string;
/** Config */
config: string;
variant: components["schemas"]["ModelVariantType"];
};
/** StableDiffusionXLModelDiffusersConfig */
StableDiffusionXLModelDiffusersConfig: {
/** Model Name */
model_name: string;
base_model: components["schemas"]["BaseModelType"];
/**
* Model Type
* @enum {string}
*/
model_type: "main";
/** Path */
path: string;
/** Description */
description?: string;
/**
* Model Format
* @enum {string}
*/
model_format: "diffusers";
error?: components["schemas"]["ModelError"];
/** Vae */
vae?: string;
variant: components["schemas"]["ModelVariantType"];
};
/**
* StepParamEasingInvocation
* @description Experimental per-step parameter easing for denoising steps
@ -4313,7 +4952,7 @@ export type components = {
* @description An enumeration.
* @enum {string}
*/
SubModelType: "unet" | "text_encoder" | "tokenizer" | "vae" | "scheduler" | "safety_checker";
SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "scheduler" | "safety_checker";
/**
* SubtractInvocation
* @description Subtracts two numbers
@ -4612,18 +5251,24 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -4734,7 +5379,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -4771,7 +5416,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["RealESRGANInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -4969,8 +5614,8 @@ export type operations = {
list_models: {
parameters: {
query?: {
/** @description Base model */
base_model?: components["schemas"]["BaseModelType"];
/** @description Base models to include */
base_models?: (components["schemas"]["BaseModelType"])[];
/** @description The type of model to get */
model_type?: components["schemas"]["ModelType"];
};
@ -5035,14 +5680,14 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
responses: {
/** @description The model was updated successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Bad request */
@ -5073,13 +5718,15 @@ export type operations = {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description The model could not be found */
404: never;
/** @description There is already a model corresponding to this path or repo_id */
409: never;
/** @description Unrecognized file/folder format */
415: never;
/** @description Validation Error */
422: {
content: {
@ -5097,14 +5744,14 @@ export type operations = {
add_model: {
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
responses: {
/** @description The model added successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description The model could not be found */
@ -5144,7 +5791,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Bad request */
@ -5233,7 +5880,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"];
};
};
/** @description Incompatible models */

View File

@ -47,10 +47,12 @@ export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
| components['schemas']['StableDiffusion2ModelDiffusersConfig']
| components['schemas']['StableDiffusionXLModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusionXLModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig

View File

@ -38,7 +38,7 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=1.2.1",
"compel==2.0.0rc2",
"controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",