mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add model loader node; unet, clip, vae fields; change compel node to clip field
This commit is contained in:
parent
131145eab1
commit
3b2a054f7a
@ -3,6 +3,8 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
|
|
||||||
|
from .model import ClipField
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||||
@ -41,7 +43,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
type: Literal["compel"] = "compel"
|
type: Literal["compel"] = "compel"
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = Field(default="", description="Prompt")
|
||||||
model: str = Field(default="", description="Model to use")
|
clip: ClipField = Field(None, description="Clip to use")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -58,12 +60,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
|
|
||||||
# TODO: load without model
|
# TODO: load without model
|
||||||
model = context.services.model_manager.get_model(self.model)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
self.model, SDModelType.diffusers, SDModelType.text_encoder
|
model_name=self.clip.text_encoder.model_name,
|
||||||
|
model_type=SDModelType[self.clip.text_encoder.model_type],
|
||||||
|
submodel=SDModelType[self.clip.text_encoder.submodel],
|
||||||
)
|
)
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
self.model, SDModelType.diffusers, SDModelType.tokenizer
|
model_name=self.clip.tokenizer.model_name,
|
||||||
|
model_type=SDModelType[self.clip.tokenizer.model_type],
|
||||||
|
submodel=SDModelType[self.clip.tokenizer.submodel],
|
||||||
)
|
)
|
||||||
with text_encoder_info.context as text_encoder,\
|
with text_encoder_info.context as text_encoder,\
|
||||||
tokenizer_info.context as tokenizer:
|
tokenizer_info.context as tokenizer:
|
||||||
|
131
invokeai/app/invocations/model.py
Normal file
131
invokeai/app/invocations/model.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
from typing import Literal, Optional, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
|
|
||||||
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
|
from ...backend.model_management import SDModelType
|
||||||
|
|
||||||
|
class ModelInfo(BaseModel):
|
||||||
|
model_name: str = Field(description="Info to load unet submodel")
|
||||||
|
model_type: str = Field(description="Info to load unet submodel")
|
||||||
|
submodel: Optional[str] = Field(description="Info to load unet submodel")
|
||||||
|
|
||||||
|
class UNetField(BaseModel):
|
||||||
|
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||||
|
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||||
|
# loras: List[ModelInfo]
|
||||||
|
|
||||||
|
class ClipField(BaseModel):
|
||||||
|
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||||
|
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||||
|
# loras: List[ModelInfo]
|
||||||
|
|
||||||
|
class VaeField(BaseModel):
|
||||||
|
# TODO: better naming?
|
||||||
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Model loader output"""
|
||||||
|
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["model_loader_output"] = "model_loader_output"
|
||||||
|
|
||||||
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loading submodels of selected model."""
|
||||||
|
|
||||||
|
type: Literal["model_loader"] = "model_loader"
|
||||||
|
|
||||||
|
model_name: str = Field(default="", description="Model to load")
|
||||||
|
# TODO: precision?
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["model", "loader"],
|
||||||
|
"type_hints": {
|
||||||
|
"model_name": "model" # TODO: rename to model_name?
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.services.model_manager.valid_model(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not context.services.model_manager.valid_model(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers,
|
||||||
|
submodel=SDModelType.tokenizer,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.valid_model(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers,
|
||||||
|
submodel=SDModelType.text_encoder,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.valid_model(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers,
|
||||||
|
submodel=SDModelType.unet,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
return ModelLoaderOutput(
|
||||||
|
unet=UNetField(
|
||||||
|
unet=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers.name,
|
||||||
|
submodel=SDModelType.unet.name,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers.name,
|
||||||
|
submodel=SDModelType.scheduler.name,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers.name,
|
||||||
|
submodel=SDModelType.tokenizer.name,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers.name,
|
||||||
|
submodel=SDModelType.text_encoder.name,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.diffusers.name,
|
||||||
|
submodel=SDModelType.vae.name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
@ -236,7 +236,7 @@ class ModelManager(object):
|
|||||||
Given a model name, returns True if it is a valid
|
Given a model name, returns True if it is a valid
|
||||||
identifier.
|
identifier.
|
||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, model_class)
|
model_key = self.create_key(model_name, model_type)
|
||||||
return model_key in self.config
|
return model_key in self.config
|
||||||
|
|
||||||
def create_key(self, model_name: str, model_type: SDModelType) -> str:
|
def create_key(self, model_name: str, model_type: SDModelType) -> str:
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
import { forEach, size } from 'lodash-es';
|
import { forEach, size } from 'lodash-es';
|
||||||
import { ImageField, LatentsField, ConditioningField } from 'services/api';
|
import {
|
||||||
|
ImageField,
|
||||||
|
LatentsField,
|
||||||
|
ConditioningField,
|
||||||
|
UNetField,
|
||||||
|
ClipField,
|
||||||
|
VaeField,
|
||||||
|
} from 'services/api';
|
||||||
|
|
||||||
const OBJECT_TYPESTRING = '[object Object]';
|
const OBJECT_TYPESTRING = '[object Object]';
|
||||||
const STRING_TYPESTRING = '[object String]';
|
const STRING_TYPESTRING = '[object String]';
|
||||||
@ -98,6 +105,101 @@ const parseConditioningField = (
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => {
|
||||||
|
// Must be an object
|
||||||
|
if (!isObject(modelInfo)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
model_name: modelInfo.model_name,
|
||||||
|
model_type: modelInfo.model_type,
|
||||||
|
submodel: modelInfo.submodel,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const parseUNetField = (unetField: unknown): UNetField | undefined => {
|
||||||
|
// Must be an object
|
||||||
|
if (!isObject(unetField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!('unet' in unetField && 'scheduler' in unetField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const unet = _parseModelInfo(unetField.unet);
|
||||||
|
const scheduler = _parseModelInfo(unetField.scheduler);
|
||||||
|
|
||||||
|
if (!(unet && scheduler)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a valid UNetField
|
||||||
|
return {
|
||||||
|
unet: unet,
|
||||||
|
scheduler: scheduler,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const parseClipField = (clipField: unknown): ClipField | undefined => {
|
||||||
|
// Must be an object
|
||||||
|
if (!isObject(clipField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!('tokenizer' in clipField && 'text_encoder' in clipField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tokenizer = _parseModelInfo(clipField.tokenizer);
|
||||||
|
const text_encoder = _parseModelInfo(clipField.text_encoder);
|
||||||
|
|
||||||
|
if (!(tokenizer && text_encoder)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a valid ClipField
|
||||||
|
return {
|
||||||
|
tokenizer: tokenizer,
|
||||||
|
text_encoder: text_encoder,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const parseVaeField = (vaeField: unknown): VaeField | undefined => {
|
||||||
|
// Must be an object
|
||||||
|
if (!isObject(vaeField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!('vae' in vaeField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const vae = _parseModelInfo(vaeField.vae);
|
||||||
|
|
||||||
|
if (!vae) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a valid VaeField
|
||||||
|
return {
|
||||||
|
vae: vae,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
type NodeMetadata = {
|
type NodeMetadata = {
|
||||||
[key: string]:
|
[key: string]:
|
||||||
| string
|
| string
|
||||||
@ -105,7 +207,10 @@ type NodeMetadata = {
|
|||||||
| boolean
|
| boolean
|
||||||
| ImageField
|
| ImageField
|
||||||
| LatentsField
|
| LatentsField
|
||||||
| ConditioningField;
|
| ConditioningField
|
||||||
|
| UNetField
|
||||||
|
| ClipField
|
||||||
|
| VaeField;
|
||||||
};
|
};
|
||||||
|
|
||||||
type InvokeAIMetadata = {
|
type InvokeAIMetadata = {
|
||||||
@ -131,7 +236,8 @@ export const parseNodeMetadata = (
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// the only valid object types are ImageField, LatentsField and ConditioningField
|
// valid object types are:
|
||||||
|
// ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField
|
||||||
if (isObject(nodeItem)) {
|
if (isObject(nodeItem)) {
|
||||||
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
||||||
const imageField = parseImageField(nodeItem);
|
const imageField = parseImageField(nodeItem);
|
||||||
@ -156,6 +262,27 @@ export const parseNodeMetadata = (
|
|||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ('unet' in nodeItem && 'tokenizer' in nodeItem) {
|
||||||
|
const unetField = parseUNetField(nodeItem);
|
||||||
|
if (unetField) {
|
||||||
|
parsed[nodeKey] = unetField;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) {
|
||||||
|
const clipField = parseClipField(nodeItem);
|
||||||
|
if (clipField) {
|
||||||
|
parsed[nodeKey] = clipField;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ('vae' in nodeItem) {
|
||||||
|
const vaeField = parseVaeField(nodeItem);
|
||||||
|
if (vaeField) {
|
||||||
|
parsed[nodeKey] = vaeField;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise we accept any string, number or boolean
|
// otherwise we accept any string, number or boolean
|
||||||
|
@ -7,6 +7,9 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
|||||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||||
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
|
||||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||||
|
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||||
|
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||||
|
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||||
@ -97,6 +100,36 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'unet' && template.type === 'unet') {
|
||||||
|
return (
|
||||||
|
<UNetInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'clip' && template.type === 'clip') {
|
||||||
|
return (
|
||||||
|
<ClipInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'vae' && template.type === 'vae') {
|
||||||
|
return (
|
||||||
|
<VaeInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'model' && template.type === 'model') {
|
if (type === 'model' && template.type === 'model') {
|
||||||
return (
|
return (
|
||||||
<ModelInputFieldComponent
|
<ModelInputFieldComponent
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
import {
|
||||||
|
ClipInputFieldTemplate,
|
||||||
|
ClipInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const ClipInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<ClipInputFieldValue, ClipInputFieldTemplate>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ClipInputFieldComponent);
|
@ -0,0 +1,16 @@
|
|||||||
|
import {
|
||||||
|
UNetInputFieldTemplate,
|
||||||
|
UNetInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const UNetInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<UNetInputFieldValue, UNetInputFieldTemplate>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(UNetInputFieldComponent);
|
@ -0,0 +1,16 @@
|
|||||||
|
import {
|
||||||
|
VaeInputFieldTemplate,
|
||||||
|
VaeInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const VaeInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<VaeInputFieldValue, VaeInputFieldTemplate>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(VaeInputFieldComponent);
|
@ -11,6 +11,9 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
ImageField: 'image',
|
ImageField: 'image',
|
||||||
LatentsField: 'latents',
|
LatentsField: 'latents',
|
||||||
ConditioningField: 'conditioning',
|
ConditioningField: 'conditioning',
|
||||||
|
UNetField: 'unet',
|
||||||
|
ClipField: 'clip',
|
||||||
|
VaeField: 'vae',
|
||||||
model: 'model',
|
model: 'model',
|
||||||
array: 'array',
|
array: 'array',
|
||||||
item: 'item',
|
item: 'item',
|
||||||
@ -71,6 +74,24 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'Conditioning',
|
title: 'Conditioning',
|
||||||
description: 'Conditioning may be passed between nodes.',
|
description: 'Conditioning may be passed between nodes.',
|
||||||
},
|
},
|
||||||
|
unet: {
|
||||||
|
color: 'red',
|
||||||
|
colorCssVar: getColorTokenCssVariable('red'),
|
||||||
|
title: 'UNet',
|
||||||
|
description: 'UNet submodel.',
|
||||||
|
},
|
||||||
|
clip: {
|
||||||
|
color: 'green',
|
||||||
|
colorCssVar: getColorTokenCssVariable('green'),
|
||||||
|
title: 'Clip',
|
||||||
|
description: 'Tokenizer and text_encoder submodels.',
|
||||||
|
},
|
||||||
|
vae: {
|
||||||
|
color: 'blue',
|
||||||
|
colorCssVar: getColorTokenCssVariable('blue'),
|
||||||
|
title: 'Vae',
|
||||||
|
description: 'Vae submodel.',
|
||||||
|
},
|
||||||
model: {
|
model: {
|
||||||
color: 'teal',
|
color: 'teal',
|
||||||
colorCssVar: getColorTokenCssVariable('teal'),
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
@ -58,6 +58,9 @@ export type FieldType =
|
|||||||
| 'image'
|
| 'image'
|
||||||
| 'latents'
|
| 'latents'
|
||||||
| 'conditioning'
|
| 'conditioning'
|
||||||
|
| 'unet'
|
||||||
|
| 'clip'
|
||||||
|
| 'vae'
|
||||||
| 'model'
|
| 'model'
|
||||||
| 'array'
|
| 'array'
|
||||||
| 'item'
|
| 'item'
|
||||||
@ -79,6 +82,9 @@ export type InputFieldValue =
|
|||||||
| ImageInputFieldValue
|
| ImageInputFieldValue
|
||||||
| LatentsInputFieldValue
|
| LatentsInputFieldValue
|
||||||
| ConditioningInputFieldValue
|
| ConditioningInputFieldValue
|
||||||
|
| UNetInputFieldValue
|
||||||
|
| ClipInputFieldValue
|
||||||
|
| VaeInputFieldValue
|
||||||
| EnumInputFieldValue
|
| EnumInputFieldValue
|
||||||
| ModelInputFieldValue
|
| ModelInputFieldValue
|
||||||
| ArrayInputFieldValue
|
| ArrayInputFieldValue
|
||||||
@ -99,6 +105,9 @@ export type InputFieldTemplate =
|
|||||||
| ImageInputFieldTemplate
|
| ImageInputFieldTemplate
|
||||||
| LatentsInputFieldTemplate
|
| LatentsInputFieldTemplate
|
||||||
| ConditioningInputFieldTemplate
|
| ConditioningInputFieldTemplate
|
||||||
|
| UNetInputFieldTemplate
|
||||||
|
| ClipInputFieldTemplate
|
||||||
|
| VaeInputFieldTemplate
|
||||||
| EnumInputFieldTemplate
|
| EnumInputFieldTemplate
|
||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
| ArrayInputFieldTemplate
|
| ArrayInputFieldTemplate
|
||||||
@ -177,6 +186,21 @@ export type ConditioningInputFieldValue = FieldValueBase & {
|
|||||||
value?: undefined;
|
value?: undefined;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type UNetInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'unet';
|
||||||
|
value?: undefined;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ClipInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'clip';
|
||||||
|
value?: undefined;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type VaeInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'vae';
|
||||||
|
value?: undefined;
|
||||||
|
};
|
||||||
|
|
||||||
export type ImageInputFieldValue = FieldValueBase & {
|
export type ImageInputFieldValue = FieldValueBase & {
|
||||||
type: 'image';
|
type: 'image';
|
||||||
value?: Pick<ImageField, 'image_name' | 'image_type'>;
|
value?: Pick<ImageField, 'image_name' | 'image_type'>;
|
||||||
|
@ -10,6 +10,9 @@ import {
|
|||||||
IntegerInputFieldTemplate,
|
IntegerInputFieldTemplate,
|
||||||
LatentsInputFieldTemplate,
|
LatentsInputFieldTemplate,
|
||||||
ConditioningInputFieldTemplate,
|
ConditioningInputFieldTemplate,
|
||||||
|
UNetInputFieldTemplate,
|
||||||
|
ClipInputFieldTemplate,
|
||||||
|
VaeInputFieldTemplate,
|
||||||
StringInputFieldTemplate,
|
StringInputFieldTemplate,
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
ArrayInputFieldTemplate,
|
ArrayInputFieldTemplate,
|
||||||
@ -215,6 +218,51 @@ const buildConditioningInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildUNetInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): UNetInputFieldTemplate => {
|
||||||
|
const template: UNetInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'unet',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'connection',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildClipInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ClipInputFieldTemplate => {
|
||||||
|
const template: ClipInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'clip',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'connection',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildVaeInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): VaeInputFieldTemplate => {
|
||||||
|
const template: VaeInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'vae',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'connection',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildEnumInputFieldTemplate = ({
|
const buildEnumInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -331,6 +379,15 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['conditioning'].includes(fieldType)) {
|
if (['conditioning'].includes(fieldType)) {
|
||||||
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
|
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['unet'].includes(fieldType)) {
|
||||||
|
return buildUNetInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
|
if (['clip'].includes(fieldType)) {
|
||||||
|
return buildClipInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
|
if (['vae'].includes(fieldType)) {
|
||||||
|
return buildVaeInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['model'].includes(fieldType)) {
|
if (['model'].includes(fieldType)) {
|
||||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -52,6 +52,18 @@ export const buildInputFieldValue = (
|
|||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'unet') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (template.type === 'clip') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (template.type === 'vae') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
if (template.type === 'model') {
|
if (template.type === 'model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user