diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 25bb421544..584e14ea0f 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, Field from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig +from .model import ClipField + from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager @@ -41,7 +43,7 @@ class CompelInvocation(BaseInvocation): type: Literal["compel"] = "compel" prompt: str = Field(default="", description="Prompt") - model: str = Field(default="", description="Model to use") + clip: ClipField = Field(None, description="Clip to use") # Schema customisation class Config(InvocationConfig): @@ -58,12 +60,15 @@ class CompelInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> CompelOutput: # TODO: load without model - model = context.services.model_manager.get_model(self.model) text_encoder_info = context.services.model_manager.get_model( - self.model, SDModelType.diffusers, SDModelType.text_encoder + model_name=self.clip.text_encoder.model_name, + model_type=SDModelType[self.clip.text_encoder.model_type], + submodel=SDModelType[self.clip.text_encoder.submodel], ) tokenizer_info = context.services.model_manager.get_model( - self.model, SDModelType.diffusers, SDModelType.tokenizer + model_name=self.clip.tokenizer.model_name, + model_type=SDModelType[self.clip.tokenizer.model_type], + submodel=SDModelType[self.clip.tokenizer.submodel], ) with text_encoder_info.context as text_encoder,\ tokenizer_info.context as tokenizer: diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py new file mode 100644 index 0000000000..11f1fd5c5f --- /dev/null +++ b/invokeai/app/invocations/model.py @@ -0,0 +1,131 @@ +from typing import Literal, Optional, Union +from pydantic import BaseModel, Field + +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig + +from ...backend.util.devices import choose_torch_device, torch_dtype +from ...backend.model_management import SDModelType + +class ModelInfo(BaseModel): + model_name: str = Field(description="Info to load unet submodel") + model_type: str = Field(description="Info to load unet submodel") + submodel: Optional[str] = Field(description="Info to load unet submodel") + +class UNetField(BaseModel): + unet: ModelInfo = Field(description="Info to load unet submodel") + scheduler: ModelInfo = Field(description="Info to load scheduler submodel") + # loras: List[ModelInfo] + +class ClipField(BaseModel): + tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") + text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") + # loras: List[ModelInfo] + +class VaeField(BaseModel): + # TODO: better naming? + vae: ModelInfo = Field(description="Info to load vae submodel") + + +class ModelLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + #fmt: off + type: Literal["model_loader_output"] = "model_loader_output" + + unet: UNetField = Field(default=None, description="UNet submodel") + clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") + vae: VaeField = Field(default=None, description="Vae submodel") + #fmt: on + + +class ModelLoaderInvocation(BaseInvocation): + """Loading submodels of selected model.""" + + type: Literal["model_loader"] = "model_loader" + + model_name: str = Field(default="", description="Model to load") + # TODO: precision? + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["model", "loader"], + "type_hints": { + "model_name": "model" # TODO: rename to model_name? + } + }, + } + + def invoke(self, context: InvocationContext) -> ModelLoaderOutput: + + # TODO: not found exceptions + if not context.services.model_manager.valid_model( + model_name=self.model_name, + model_type=SDModelType.diffusers, + ): + raise Exception(f"Unkown model name: {self.model_name}!") + + """ + if not context.services.model_manager.valid_model( + model_name=self.model_name, + model_type=SDModelType.diffusers, + submodel=SDModelType.tokenizer, + ): + raise Exception( + f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" + ) + + if not context.services.model_manager.valid_model( + model_name=self.model_name, + model_type=SDModelType.diffusers, + submodel=SDModelType.text_encoder, + ): + raise Exception( + f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" + ) + + if not context.services.model_manager.valid_model( + model_name=self.model_name, + model_type=SDModelType.diffusers, + submodel=SDModelType.unet, + ): + raise Exception( + f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" + ) + """ + + + return ModelLoaderOutput( + unet=UNetField( + unet=ModelInfo( + model_name=self.model_name, + model_type=SDModelType.diffusers.name, + submodel=SDModelType.unet.name, + ), + scheduler=ModelInfo( + model_name=self.model_name, + model_type=SDModelType.diffusers.name, + submodel=SDModelType.scheduler.name, + ), + ), + clip=ClipField( + tokenizer=ModelInfo( + model_name=self.model_name, + model_type=SDModelType.diffusers.name, + submodel=SDModelType.tokenizer.name, + ), + text_encoder=ModelInfo( + model_name=self.model_name, + model_type=SDModelType.diffusers.name, + submodel=SDModelType.text_encoder.name, + ), + ), + vae=VaeField( + vae=ModelInfo( + model_name=self.model_name, + model_type=SDModelType.diffusers.name, + submodel=SDModelType.vae.name, + ), + ) + ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index bd70e23f5b..e32ba3b22a 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -236,7 +236,7 @@ class ModelManager(object): Given a model name, returns True if it is a valid identifier. """ - model_key = self.create_key(model_name, model_class) + model_key = self.create_key(model_name, model_type) return model_key in self.config def create_key(self, model_name: str, model_type: SDModelType) -> str: diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts index c27833218b..95d08db3e0 100644 --- a/invokeai/frontend/web/src/common/util/parseMetadata.ts +++ b/invokeai/frontend/web/src/common/util/parseMetadata.ts @@ -1,5 +1,12 @@ import { forEach, size } from 'lodash-es'; -import { ImageField, LatentsField, ConditioningField } from 'services/api'; +import { + ImageField, + LatentsField, + ConditioningField, + UNetField, + ClipField, + VaeField, +} from 'services/api'; const OBJECT_TYPESTRING = '[object Object]'; const STRING_TYPESTRING = '[object String]'; @@ -98,6 +105,101 @@ const parseConditioningField = ( }; }; +const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => { + // Must be an object + if (!isObject(modelInfo)) { + return; + } + + if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) { + return; + } + + if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) { + return; + } + + if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) { + return; + } + + return { + model_name: modelInfo.model_name, + model_type: modelInfo.model_type, + submodel: modelInfo.submodel, + }; +}; + +const parseUNetField = (unetField: unknown): UNetField | undefined => { + // Must be an object + if (!isObject(unetField)) { + return; + } + + if (!('unet' in unetField && 'scheduler' in unetField)) { + return; + } + + const unet = _parseModelInfo(unetField.unet); + const scheduler = _parseModelInfo(unetField.scheduler); + + if (!(unet && scheduler)) { + return; + } + + // Build a valid UNetField + return { + unet: unet, + scheduler: scheduler, + }; +}; + +const parseClipField = (clipField: unknown): ClipField | undefined => { + // Must be an object + if (!isObject(clipField)) { + return; + } + + if (!('tokenizer' in clipField && 'text_encoder' in clipField)) { + return; + } + + const tokenizer = _parseModelInfo(clipField.tokenizer); + const text_encoder = _parseModelInfo(clipField.text_encoder); + + if (!(tokenizer && text_encoder)) { + return; + } + + // Build a valid ClipField + return { + tokenizer: tokenizer, + text_encoder: text_encoder, + }; +}; + +const parseVaeField = (vaeField: unknown): VaeField | undefined => { + // Must be an object + if (!isObject(vaeField)) { + return; + } + + if (!('vae' in vaeField)) { + return; + } + + const vae = _parseModelInfo(vaeField.vae); + + if (!vae) { + return; + } + + // Build a valid VaeField + return { + vae: vae, + }; +}; + type NodeMetadata = { [key: string]: | string @@ -105,7 +207,10 @@ type NodeMetadata = { | boolean | ImageField | LatentsField - | ConditioningField; + | ConditioningField + | UNetField + | ClipField + | VaeField; }; type InvokeAIMetadata = { @@ -131,7 +236,8 @@ export const parseNodeMetadata = ( return; } - // the only valid object types are ImageField, LatentsField and ConditioningField + // valid object types are: + // ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField if (isObject(nodeItem)) { if ('image_name' in nodeItem || 'image_type' in nodeItem) { const imageField = parseImageField(nodeItem); @@ -156,6 +262,27 @@ export const parseNodeMetadata = ( } return; } + + if ('unet' in nodeItem && 'tokenizer' in nodeItem) { + const unetField = parseUNetField(nodeItem); + if (unetField) { + parsed[nodeKey] = unetField; + } + } + + if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) { + const clipField = parseClipField(nodeItem); + if (clipField) { + parsed[nodeKey] = clipField; + } + } + + if ('vae' in nodeItem) { + const vaeField = parseVaeField(nodeItem); + if (vaeField) { + parsed[nodeKey] = vaeField; + } + } } // otherwise we accept any string, number or boolean diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 9527708c40..341ca19fa9 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -7,6 +7,9 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; +import UNetInputFieldComponent from './fields/UNetInputFieldComponent'; +import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; +import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; @@ -97,6 +100,36 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'unet' && template.type === 'unet') { + return ( + + ); + } + + if (type === 'clip' && template.type === 'clip') { + return ( + + ); + } + + if (type === 'vae' && template.type === 'vae') { + return ( + + ); + } + if (type === 'model' && template.type === 'model') { return ( +) => { + const { nodeId, field } = props; + + return null; +}; + +export default memo(ClipInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx new file mode 100644 index 0000000000..5926bf113a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/UNetInputFieldComponent.tsx @@ -0,0 +1,16 @@ +import { + UNetInputFieldTemplate, + UNetInputFieldValue, +} from 'features/nodes/types/types'; +import { memo } from 'react'; +import { FieldComponentProps } from './types'; + +const UNetInputFieldComponent = ( + props: FieldComponentProps +) => { + const { nodeId, field } = props; + + return null; +}; + +export default memo(UNetInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx new file mode 100644 index 0000000000..0fa11ae34e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/VaeInputFieldComponent.tsx @@ -0,0 +1,16 @@ +import { + VaeInputFieldTemplate, + VaeInputFieldValue, +} from 'features/nodes/types/types'; +import { memo } from 'react'; +import { FieldComponentProps } from './types'; + +const VaeInputFieldComponent = ( + props: FieldComponentProps +) => { + const { nodeId, field } = props; + + return null; +}; + +export default memo(VaeInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 7e4dadc21d..36c3514eeb 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -11,6 +11,9 @@ export const FIELD_TYPE_MAP: Record = { ImageField: 'image', LatentsField: 'latents', ConditioningField: 'conditioning', + UNetField: 'unet', + ClipField: 'clip', + VaeField: 'vae', model: 'model', array: 'array', item: 'item', @@ -71,6 +74,24 @@ export const FIELDS: Record = { title: 'Conditioning', description: 'Conditioning may be passed between nodes.', }, + unet: { + color: 'red', + colorCssVar: getColorTokenCssVariable('red'), + title: 'UNet', + description: 'UNet submodel.', + }, + clip: { + color: 'green', + colorCssVar: getColorTokenCssVariable('green'), + title: 'Clip', + description: 'Tokenizer and text_encoder submodels.', + }, + vae: { + color: 'blue', + colorCssVar: getColorTokenCssVariable('blue'), + title: 'Vae', + description: 'Vae submodel.', + }, model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 876ba95cac..0ed0a46964 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -58,6 +58,9 @@ export type FieldType = | 'image' | 'latents' | 'conditioning' + | 'unet' + | 'clip' + | 'vae' | 'model' | 'array' | 'item' @@ -79,6 +82,9 @@ export type InputFieldValue = | ImageInputFieldValue | LatentsInputFieldValue | ConditioningInputFieldValue + | UNetInputFieldValue + | ClipInputFieldValue + | VaeInputFieldValue | EnumInputFieldValue | ModelInputFieldValue | ArrayInputFieldValue @@ -99,6 +105,9 @@ export type InputFieldTemplate = | ImageInputFieldTemplate | LatentsInputFieldTemplate | ConditioningInputFieldTemplate + | UNetInputFieldTemplate + | ClipInputFieldTemplate + | VaeInputFieldTemplate | EnumInputFieldTemplate | ModelInputFieldTemplate | ArrayInputFieldTemplate @@ -177,6 +186,21 @@ export type ConditioningInputFieldValue = FieldValueBase & { value?: undefined; }; +export type UNetInputFieldValue = FieldValueBase & { + type: 'unet'; + value?: undefined; +}; + +export type ClipInputFieldValue = FieldValueBase & { + type: 'clip'; + value?: undefined; +}; + +export type VaeInputFieldValue = FieldValueBase & { + type: 'vae'; + value?: undefined; +}; + export type ImageInputFieldValue = FieldValueBase & { type: 'image'; value?: Pick; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 11f0087488..b275c84248 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -10,6 +10,9 @@ import { IntegerInputFieldTemplate, LatentsInputFieldTemplate, ConditioningInputFieldTemplate, + UNetInputFieldTemplate, + ClipInputFieldTemplate, + VaeInputFieldTemplate, StringInputFieldTemplate, ModelInputFieldTemplate, ArrayInputFieldTemplate, @@ -215,6 +218,51 @@ const buildConditioningInputFieldTemplate = ({ return template; }; +const buildUNetInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): UNetInputFieldTemplate => { + const template: UNetInputFieldTemplate = { + ...baseField, + type: 'unet', + inputRequirement: 'always', + inputKind: 'connection', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildClipInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ClipInputFieldTemplate => { + const template: ClipInputFieldTemplate = { + ...baseField, + type: 'clip', + inputRequirement: 'always', + inputKind: 'connection', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildVaeInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): VaeInputFieldTemplate => { + const template: VaeInputFieldTemplate = { + ...baseField, + type: 'vae', + inputRequirement: 'always', + inputKind: 'connection', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildEnumInputFieldTemplate = ({ schemaObject, baseField, @@ -331,6 +379,15 @@ export const buildInputFieldTemplate = ( if (['conditioning'].includes(fieldType)) { return buildConditioningInputFieldTemplate({ schemaObject, baseField }); } + if (['unet'].includes(fieldType)) { + return buildUNetInputFieldTemplate({ schemaObject, baseField }); + } + if (['clip'].includes(fieldType)) { + return buildClipInputFieldTemplate({ schemaObject, baseField }); + } + if (['vae'].includes(fieldType)) { + return buildVaeInputFieldTemplate({ schemaObject, baseField }); + } if (['model'].includes(fieldType)) { return buildModelInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 9221e5f7ac..c0c19708c7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -52,6 +52,18 @@ export const buildInputFieldValue = ( fieldValue.value = undefined; } + if (template.type === 'unet') { + fieldValue.value = undefined; + } + + if (template.type === 'clip') { + fieldValue.value = undefined; + } + + if (template.type === 'vae') { + fieldValue.value = undefined; + } + if (template.type === 'model') { fieldValue.value = undefined; }