diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 25bb421544..584e14ea0f 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -3,6 +3,8 @@ from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
+from .model import ClipField
+
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
@@ -41,7 +43,7 @@ class CompelInvocation(BaseInvocation):
type: Literal["compel"] = "compel"
prompt: str = Field(default="", description="Prompt")
- model: str = Field(default="", description="Model to use")
+ clip: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
@@ -58,12 +60,15 @@ class CompelInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: load without model
- model = context.services.model_manager.get_model(self.model)
text_encoder_info = context.services.model_manager.get_model(
- self.model, SDModelType.diffusers, SDModelType.text_encoder
+ model_name=self.clip.text_encoder.model_name,
+ model_type=SDModelType[self.clip.text_encoder.model_type],
+ submodel=SDModelType[self.clip.text_encoder.submodel],
)
tokenizer_info = context.services.model_manager.get_model(
- self.model, SDModelType.diffusers, SDModelType.tokenizer
+ model_name=self.clip.tokenizer.model_name,
+ model_type=SDModelType[self.clip.tokenizer.model_type],
+ submodel=SDModelType[self.clip.tokenizer.submodel],
)
with text_encoder_info.context as text_encoder,\
tokenizer_info.context as tokenizer:
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
new file mode 100644
index 0000000000..11f1fd5c5f
--- /dev/null
+++ b/invokeai/app/invocations/model.py
@@ -0,0 +1,131 @@
+from typing import Literal, Optional, Union
+from pydantic import BaseModel, Field
+
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
+
+from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.model_management import SDModelType
+
+class ModelInfo(BaseModel):
+ model_name: str = Field(description="Info to load unet submodel")
+ model_type: str = Field(description="Info to load unet submodel")
+ submodel: Optional[str] = Field(description="Info to load unet submodel")
+
+class UNetField(BaseModel):
+ unet: ModelInfo = Field(description="Info to load unet submodel")
+ scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
+ # loras: List[ModelInfo]
+
+class ClipField(BaseModel):
+ tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
+ text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
+ # loras: List[ModelInfo]
+
+class VaeField(BaseModel):
+ # TODO: better naming?
+ vae: ModelInfo = Field(description="Info to load vae submodel")
+
+
+class ModelLoaderOutput(BaseInvocationOutput):
+ """Model loader output"""
+
+ #fmt: off
+ type: Literal["model_loader_output"] = "model_loader_output"
+
+ unet: UNetField = Field(default=None, description="UNet submodel")
+ clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
+ vae: VaeField = Field(default=None, description="Vae submodel")
+ #fmt: on
+
+
+class ModelLoaderInvocation(BaseInvocation):
+ """Loading submodels of selected model."""
+
+ type: Literal["model_loader"] = "model_loader"
+
+ model_name: str = Field(default="", description="Model to load")
+ # TODO: precision?
+
+ # Schema customisation
+ class Config(InvocationConfig):
+ schema_extra = {
+ "ui": {
+ "tags": ["model", "loader"],
+ "type_hints": {
+ "model_name": "model" # TODO: rename to model_name?
+ }
+ },
+ }
+
+ def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
+
+ # TODO: not found exceptions
+ if not context.services.model_manager.valid_model(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers,
+ ):
+ raise Exception(f"Unkown model name: {self.model_name}!")
+
+ """
+ if not context.services.model_manager.valid_model(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers,
+ submodel=SDModelType.tokenizer,
+ ):
+ raise Exception(
+ f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
+ )
+
+ if not context.services.model_manager.valid_model(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers,
+ submodel=SDModelType.text_encoder,
+ ):
+ raise Exception(
+ f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
+ )
+
+ if not context.services.model_manager.valid_model(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers,
+ submodel=SDModelType.unet,
+ ):
+ raise Exception(
+ f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
+ )
+ """
+
+
+ return ModelLoaderOutput(
+ unet=UNetField(
+ unet=ModelInfo(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers.name,
+ submodel=SDModelType.unet.name,
+ ),
+ scheduler=ModelInfo(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers.name,
+ submodel=SDModelType.scheduler.name,
+ ),
+ ),
+ clip=ClipField(
+ tokenizer=ModelInfo(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers.name,
+ submodel=SDModelType.tokenizer.name,
+ ),
+ text_encoder=ModelInfo(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers.name,
+ submodel=SDModelType.text_encoder.name,
+ ),
+ ),
+ vae=VaeField(
+ vae=ModelInfo(
+ model_name=self.model_name,
+ model_type=SDModelType.diffusers.name,
+ submodel=SDModelType.vae.name,
+ ),
+ )
+ )
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index bd70e23f5b..e32ba3b22a 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -236,7 +236,7 @@ class ModelManager(object):
Given a model name, returns True if it is a valid
identifier.
"""
- model_key = self.create_key(model_name, model_class)
+ model_key = self.create_key(model_name, model_type)
return model_key in self.config
def create_key(self, model_name: str, model_type: SDModelType) -> str:
diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts
index c27833218b..95d08db3e0 100644
--- a/invokeai/frontend/web/src/common/util/parseMetadata.ts
+++ b/invokeai/frontend/web/src/common/util/parseMetadata.ts
@@ -1,5 +1,12 @@
import { forEach, size } from 'lodash-es';
-import { ImageField, LatentsField, ConditioningField } from 'services/api';
+import {
+ ImageField,
+ LatentsField,
+ ConditioningField,
+ UNetField,
+ ClipField,
+ VaeField,
+} from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
@@ -98,6 +105,101 @@ const parseConditioningField = (
};
};
+const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => {
+ // Must be an object
+ if (!isObject(modelInfo)) {
+ return;
+ }
+
+ if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) {
+ return;
+ }
+
+ if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) {
+ return;
+ }
+
+ if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) {
+ return;
+ }
+
+ return {
+ model_name: modelInfo.model_name,
+ model_type: modelInfo.model_type,
+ submodel: modelInfo.submodel,
+ };
+};
+
+const parseUNetField = (unetField: unknown): UNetField | undefined => {
+ // Must be an object
+ if (!isObject(unetField)) {
+ return;
+ }
+
+ if (!('unet' in unetField && 'scheduler' in unetField)) {
+ return;
+ }
+
+ const unet = _parseModelInfo(unetField.unet);
+ const scheduler = _parseModelInfo(unetField.scheduler);
+
+ if (!(unet && scheduler)) {
+ return;
+ }
+
+ // Build a valid UNetField
+ return {
+ unet: unet,
+ scheduler: scheduler,
+ };
+};
+
+const parseClipField = (clipField: unknown): ClipField | undefined => {
+ // Must be an object
+ if (!isObject(clipField)) {
+ return;
+ }
+
+ if (!('tokenizer' in clipField && 'text_encoder' in clipField)) {
+ return;
+ }
+
+ const tokenizer = _parseModelInfo(clipField.tokenizer);
+ const text_encoder = _parseModelInfo(clipField.text_encoder);
+
+ if (!(tokenizer && text_encoder)) {
+ return;
+ }
+
+ // Build a valid ClipField
+ return {
+ tokenizer: tokenizer,
+ text_encoder: text_encoder,
+ };
+};
+
+const parseVaeField = (vaeField: unknown): VaeField | undefined => {
+ // Must be an object
+ if (!isObject(vaeField)) {
+ return;
+ }
+
+ if (!('vae' in vaeField)) {
+ return;
+ }
+
+ const vae = _parseModelInfo(vaeField.vae);
+
+ if (!vae) {
+ return;
+ }
+
+ // Build a valid VaeField
+ return {
+ vae: vae,
+ };
+};
+
type NodeMetadata = {
[key: string]:
| string
@@ -105,7 +207,10 @@ type NodeMetadata = {
| boolean
| ImageField
| LatentsField
- | ConditioningField;
+ | ConditioningField
+ | UNetField
+ | ClipField
+ | VaeField;
};
type InvokeAIMetadata = {
@@ -131,7 +236,8 @@ export const parseNodeMetadata = (
return;
}
- // the only valid object types are ImageField, LatentsField and ConditioningField
+ // valid object types are:
+ // ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
@@ -156,6 +262,27 @@ export const parseNodeMetadata = (
}
return;
}
+
+ if ('unet' in nodeItem && 'tokenizer' in nodeItem) {
+ const unetField = parseUNetField(nodeItem);
+ if (unetField) {
+ parsed[nodeKey] = unetField;
+ }
+ }
+
+ if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) {
+ const clipField = parseClipField(nodeItem);
+ if (clipField) {
+ parsed[nodeKey] = clipField;
+ }
+ }
+
+ if ('vae' in nodeItem) {
+ const vaeField = parseVaeField(nodeItem);
+ if (vaeField) {
+ parsed[nodeKey] = vaeField;
+ }
+ }
}
// otherwise we accept any string, number or boolean
diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
index 9527708c40..341ca19fa9 100644
--- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
@@ -7,6 +7,9 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
+import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
+import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
+import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
@@ -97,6 +100,36 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
+ if (type === 'unet' && template.type === 'unet') {
+ return (
+
+ );
+ }
+
+ if (type === 'clip' && template.type === 'clip') {
+ return (
+
+ );
+ }
+
+ if (type === 'vae' && template.type === 'vae') {
+ return (
+
+ );
+ }
+
if (type === 'model' && template.type === 'model') {
return (
+) => {
+ const { nodeId, field } = props;
+
+ return null;
+};
+
+export default memo(ClipInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx
new file mode 100644
index 0000000000..5926bf113a
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx
@@ -0,0 +1,16 @@
+import {
+ UNetInputFieldTemplate,
+ UNetInputFieldValue,
+} from 'features/nodes/types/types';
+import { memo } from 'react';
+import { FieldComponentProps } from './types';
+
+const UNetInputFieldComponent = (
+ props: FieldComponentProps
+) => {
+ const { nodeId, field } = props;
+
+ return null;
+};
+
+export default memo(UNetInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx
new file mode 100644
index 0000000000..0fa11ae34e
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx
@@ -0,0 +1,16 @@
+import {
+ VaeInputFieldTemplate,
+ VaeInputFieldValue,
+} from 'features/nodes/types/types';
+import { memo } from 'react';
+import { FieldComponentProps } from './types';
+
+const VaeInputFieldComponent = (
+ props: FieldComponentProps
+) => {
+ const { nodeId, field } = props;
+
+ return null;
+};
+
+export default memo(VaeInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts
index 7e4dadc21d..36c3514eeb 100644
--- a/invokeai/frontend/web/src/features/nodes/types/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts
@@ -11,6 +11,9 @@ export const FIELD_TYPE_MAP: Record = {
ImageField: 'image',
LatentsField: 'latents',
ConditioningField: 'conditioning',
+ UNetField: 'unet',
+ ClipField: 'clip',
+ VaeField: 'vae',
model: 'model',
array: 'array',
item: 'item',
@@ -71,6 +74,24 @@ export const FIELDS: Record = {
title: 'Conditioning',
description: 'Conditioning may be passed between nodes.',
},
+ unet: {
+ color: 'red',
+ colorCssVar: getColorTokenCssVariable('red'),
+ title: 'UNet',
+ description: 'UNet submodel.',
+ },
+ clip: {
+ color: 'green',
+ colorCssVar: getColorTokenCssVariable('green'),
+ title: 'Clip',
+ description: 'Tokenizer and text_encoder submodels.',
+ },
+ vae: {
+ color: 'blue',
+ colorCssVar: getColorTokenCssVariable('blue'),
+ title: 'Vae',
+ description: 'Vae submodel.',
+ },
model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts
index 876ba95cac..0ed0a46964 100644
--- a/invokeai/frontend/web/src/features/nodes/types/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/types.ts
@@ -58,6 +58,9 @@ export type FieldType =
| 'image'
| 'latents'
| 'conditioning'
+ | 'unet'
+ | 'clip'
+ | 'vae'
| 'model'
| 'array'
| 'item'
@@ -79,6 +82,9 @@ export type InputFieldValue =
| ImageInputFieldValue
| LatentsInputFieldValue
| ConditioningInputFieldValue
+ | UNetInputFieldValue
+ | ClipInputFieldValue
+ | VaeInputFieldValue
| EnumInputFieldValue
| ModelInputFieldValue
| ArrayInputFieldValue
@@ -99,6 +105,9 @@ export type InputFieldTemplate =
| ImageInputFieldTemplate
| LatentsInputFieldTemplate
| ConditioningInputFieldTemplate
+ | UNetInputFieldTemplate
+ | ClipInputFieldTemplate
+ | VaeInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| ArrayInputFieldTemplate
@@ -177,6 +186,21 @@ export type ConditioningInputFieldValue = FieldValueBase & {
value?: undefined;
};
+export type UNetInputFieldValue = FieldValueBase & {
+ type: 'unet';
+ value?: undefined;
+};
+
+export type ClipInputFieldValue = FieldValueBase & {
+ type: 'clip';
+ value?: undefined;
+};
+
+export type VaeInputFieldValue = FieldValueBase & {
+ type: 'vae';
+ value?: undefined;
+};
+
export type ImageInputFieldValue = FieldValueBase & {
type: 'image';
value?: Pick;
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
index 11f0087488..b275c84248 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
@@ -10,6 +10,9 @@ import {
IntegerInputFieldTemplate,
LatentsInputFieldTemplate,
ConditioningInputFieldTemplate,
+ UNetInputFieldTemplate,
+ ClipInputFieldTemplate,
+ VaeInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
ArrayInputFieldTemplate,
@@ -215,6 +218,51 @@ const buildConditioningInputFieldTemplate = ({
return template;
};
+const buildUNetInputFieldTemplate = ({
+ schemaObject,
+ baseField,
+}: BuildInputFieldArg): UNetInputFieldTemplate => {
+ const template: UNetInputFieldTemplate = {
+ ...baseField,
+ type: 'unet',
+ inputRequirement: 'always',
+ inputKind: 'connection',
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
+const buildClipInputFieldTemplate = ({
+ schemaObject,
+ baseField,
+}: BuildInputFieldArg): ClipInputFieldTemplate => {
+ const template: ClipInputFieldTemplate = {
+ ...baseField,
+ type: 'clip',
+ inputRequirement: 'always',
+ inputKind: 'connection',
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
+const buildVaeInputFieldTemplate = ({
+ schemaObject,
+ baseField,
+}: BuildInputFieldArg): VaeInputFieldTemplate => {
+ const template: VaeInputFieldTemplate = {
+ ...baseField,
+ type: 'vae',
+ inputRequirement: 'always',
+ inputKind: 'connection',
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@@ -331,6 +379,15 @@ export const buildInputFieldTemplate = (
if (['conditioning'].includes(fieldType)) {
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
}
+ if (['unet'].includes(fieldType)) {
+ return buildUNetInputFieldTemplate({ schemaObject, baseField });
+ }
+ if (['clip'].includes(fieldType)) {
+ return buildClipInputFieldTemplate({ schemaObject, baseField });
+ }
+ if (['vae'].includes(fieldType)) {
+ return buildVaeInputFieldTemplate({ schemaObject, baseField });
+ }
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
index 9221e5f7ac..c0c19708c7 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
@@ -52,6 +52,18 @@ export const buildInputFieldValue = (
fieldValue.value = undefined;
}
+ if (template.type === 'unet') {
+ fieldValue.value = undefined;
+ }
+
+ if (template.type === 'clip') {
+ fieldValue.value = undefined;
+ }
+
+ if (template.type === 'vae') {
+ fieldValue.value = undefined;
+ }
+
if (template.type === 'model') {
fieldValue.value = undefined;
}