Add model loader node; unet, clip, vae fields; change compel node to clip field

This commit is contained in:
Sergey Borisov 2023-05-13 04:37:20 +03:00
parent 131145eab1
commit 3b2a054f7a
12 changed files with 466 additions and 8 deletions

View File

@ -3,6 +3,8 @@ from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
@ -41,7 +43,7 @@ class CompelInvocation(BaseInvocation):
type: Literal["compel"] = "compel"
prompt: str = Field(default="", description="Prompt")
model: str = Field(default="", description="Model to use")
clip: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
@ -58,12 +60,15 @@ class CompelInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: load without model
model = context.services.model_manager.get_model(self.model)
text_encoder_info = context.services.model_manager.get_model(
self.model, SDModelType.diffusers, SDModelType.text_encoder
model_name=self.clip.text_encoder.model_name,
model_type=SDModelType[self.clip.text_encoder.model_type],
submodel=SDModelType[self.clip.text_encoder.submodel],
)
tokenizer_info = context.services.model_manager.get_model(
self.model, SDModelType.diffusers, SDModelType.tokenizer
model_name=self.clip.tokenizer.model_name,
model_type=SDModelType[self.clip.tokenizer.model_type],
submodel=SDModelType[self.clip.tokenizer.submodel],
)
with text_encoder_info.context as text_encoder,\
tokenizer_info.context as tokenizer:

View File

@ -0,0 +1,131 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import SDModelType
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load unet submodel")
model_type: str = Field(description="Info to load unet submodel")
submodel: Optional[str] = Field(description="Info to load unet submodel")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
# loras: List[ModelInfo]
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
# loras: List[ModelInfo]
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on
class ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["model_loader"] = "model_loader"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
# TODO: not found exceptions
if not context.services.model_manager.valid_model(
model_name=self.model_name,
model_type=SDModelType.diffusers,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.valid_model(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.valid_model(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.text_encoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.valid_model(
model_name=self.model_name,
model_type=SDModelType.diffusers,
submodel=SDModelType.unet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.unet.name,
),
scheduler=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.scheduler.name,
),
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.tokenizer.name,
),
text_encoder=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.text_encoder.name,
),
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
model_type=SDModelType.diffusers.name,
submodel=SDModelType.vae.name,
),
)
)

View File

@ -236,7 +236,7 @@ class ModelManager(object):
Given a model name, returns True if it is a valid
identifier.
"""
model_key = self.create_key(model_name, model_class)
model_key = self.create_key(model_name, model_type)
return model_key in self.config
def create_key(self, model_name: str, model_type: SDModelType) -> str:

View File

@ -1,5 +1,12 @@
import { forEach, size } from 'lodash-es';
import { ImageField, LatentsField, ConditioningField } from 'services/api';
import {
ImageField,
LatentsField,
ConditioningField,
UNetField,
ClipField,
VaeField,
} from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
@ -98,6 +105,101 @@ const parseConditioningField = (
};
};
const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => {
// Must be an object
if (!isObject(modelInfo)) {
return;
}
if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) {
return;
}
if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) {
return;
}
if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) {
return;
}
return {
model_name: modelInfo.model_name,
model_type: modelInfo.model_type,
submodel: modelInfo.submodel,
};
};
const parseUNetField = (unetField: unknown): UNetField | undefined => {
// Must be an object
if (!isObject(unetField)) {
return;
}
if (!('unet' in unetField && 'scheduler' in unetField)) {
return;
}
const unet = _parseModelInfo(unetField.unet);
const scheduler = _parseModelInfo(unetField.scheduler);
if (!(unet && scheduler)) {
return;
}
// Build a valid UNetField
return {
unet: unet,
scheduler: scheduler,
};
};
const parseClipField = (clipField: unknown): ClipField | undefined => {
// Must be an object
if (!isObject(clipField)) {
return;
}
if (!('tokenizer' in clipField && 'text_encoder' in clipField)) {
return;
}
const tokenizer = _parseModelInfo(clipField.tokenizer);
const text_encoder = _parseModelInfo(clipField.text_encoder);
if (!(tokenizer && text_encoder)) {
return;
}
// Build a valid ClipField
return {
tokenizer: tokenizer,
text_encoder: text_encoder,
};
};
const parseVaeField = (vaeField: unknown): VaeField | undefined => {
// Must be an object
if (!isObject(vaeField)) {
return;
}
if (!('vae' in vaeField)) {
return;
}
const vae = _parseModelInfo(vaeField.vae);
if (!vae) {
return;
}
// Build a valid VaeField
return {
vae: vae,
};
};
type NodeMetadata = {
[key: string]:
| string
@ -105,7 +207,10 @@ type NodeMetadata = {
| boolean
| ImageField
| LatentsField
| ConditioningField;
| ConditioningField
| UNetField
| ClipField
| VaeField;
};
type InvokeAIMetadata = {
@ -131,7 +236,8 @@ export const parseNodeMetadata = (
return;
}
// the only valid object types are ImageField, LatentsField and ConditioningField
// valid object types are:
// ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
@ -156,6 +262,27 @@ export const parseNodeMetadata = (
}
return;
}
if ('unet' in nodeItem && 'tokenizer' in nodeItem) {
const unetField = parseUNetField(nodeItem);
if (unetField) {
parsed[nodeKey] = unetField;
}
}
if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) {
const clipField = parseClipField(nodeItem);
if (clipField) {
parsed[nodeKey] = clipField;
}
}
if ('vae' in nodeItem) {
const vaeField = parseVaeField(nodeItem);
if (vaeField) {
parsed[nodeKey] = vaeField;
}
}
}
// otherwise we accept any string, number or boolean

View File

@ -7,6 +7,9 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
@ -97,6 +100,36 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'unet' && template.type === 'unet') {
return (
<UNetInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'clip' && template.type === 'clip') {
return (
<ClipInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae' && template.type === 'vae') {
return (
<VaeInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') {
return (
<ModelInputFieldComponent

View File

@ -0,0 +1,16 @@
import {
ClipInputFieldTemplate,
ClipInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const ClipInputFieldComponent = (
props: FieldComponentProps<ClipInputFieldValue, ClipInputFieldTemplate>
) => {
const { nodeId, field } = props;
return null;
};
export default memo(ClipInputFieldComponent);

View File

@ -0,0 +1,16 @@
import {
UNetInputFieldTemplate,
UNetInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const UNetInputFieldComponent = (
props: FieldComponentProps<UNetInputFieldValue, UNetInputFieldTemplate>
) => {
const { nodeId, field } = props;
return null;
};
export default memo(UNetInputFieldComponent);

View File

@ -0,0 +1,16 @@
import {
VaeInputFieldTemplate,
VaeInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const VaeInputFieldComponent = (
props: FieldComponentProps<VaeInputFieldValue, VaeInputFieldTemplate>
) => {
const { nodeId, field } = props;
return null;
};
export default memo(VaeInputFieldComponent);

View File

@ -11,6 +11,9 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ImageField: 'image',
LatentsField: 'latents',
ConditioningField: 'conditioning',
UNetField: 'unet',
ClipField: 'clip',
VaeField: 'vae',
model: 'model',
array: 'array',
item: 'item',
@ -71,6 +74,24 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Conditioning',
description: 'Conditioning may be passed between nodes.',
},
unet: {
color: 'red',
colorCssVar: getColorTokenCssVariable('red'),
title: 'UNet',
description: 'UNet submodel.',
},
clip: {
color: 'green',
colorCssVar: getColorTokenCssVariable('green'),
title: 'Clip',
description: 'Tokenizer and text_encoder submodels.',
},
vae: {
color: 'blue',
colorCssVar: getColorTokenCssVariable('blue'),
title: 'Vae',
description: 'Vae submodel.',
},
model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),

View File

@ -58,6 +58,9 @@ export type FieldType =
| 'image'
| 'latents'
| 'conditioning'
| 'unet'
| 'clip'
| 'vae'
| 'model'
| 'array'
| 'item'
@ -79,6 +82,9 @@ export type InputFieldValue =
| ImageInputFieldValue
| LatentsInputFieldValue
| ConditioningInputFieldValue
| UNetInputFieldValue
| ClipInputFieldValue
| VaeInputFieldValue
| EnumInputFieldValue
| ModelInputFieldValue
| ArrayInputFieldValue
@ -99,6 +105,9 @@ export type InputFieldTemplate =
| ImageInputFieldTemplate
| LatentsInputFieldTemplate
| ConditioningInputFieldTemplate
| UNetInputFieldTemplate
| ClipInputFieldTemplate
| VaeInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| ArrayInputFieldTemplate
@ -177,6 +186,21 @@ export type ConditioningInputFieldValue = FieldValueBase & {
value?: undefined;
};
export type UNetInputFieldValue = FieldValueBase & {
type: 'unet';
value?: undefined;
};
export type ClipInputFieldValue = FieldValueBase & {
type: 'clip';
value?: undefined;
};
export type VaeInputFieldValue = FieldValueBase & {
type: 'vae';
value?: undefined;
};
export type ImageInputFieldValue = FieldValueBase & {
type: 'image';
value?: Pick<ImageField, 'image_name' | 'image_type'>;

View File

@ -10,6 +10,9 @@ import {
IntegerInputFieldTemplate,
LatentsInputFieldTemplate,
ConditioningInputFieldTemplate,
UNetInputFieldTemplate,
ClipInputFieldTemplate,
VaeInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
ArrayInputFieldTemplate,
@ -215,6 +218,51 @@ const buildConditioningInputFieldTemplate = ({
return template;
};
const buildUNetInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): UNetInputFieldTemplate => {
const template: UNetInputFieldTemplate = {
...baseField,
type: 'unet',
inputRequirement: 'always',
inputKind: 'connection',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildClipInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ClipInputFieldTemplate => {
const template: ClipInputFieldTemplate = {
...baseField,
type: 'clip',
inputRequirement: 'always',
inputKind: 'connection',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVaeInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): VaeInputFieldTemplate => {
const template: VaeInputFieldTemplate = {
...baseField,
type: 'vae',
inputRequirement: 'always',
inputKind: 'connection',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@ -331,6 +379,15 @@ export const buildInputFieldTemplate = (
if (['conditioning'].includes(fieldType)) {
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
}
if (['unet'].includes(fieldType)) {
return buildUNetInputFieldTemplate({ schemaObject, baseField });
}
if (['clip'].includes(fieldType)) {
return buildClipInputFieldTemplate({ schemaObject, baseField });
}
if (['vae'].includes(fieldType)) {
return buildVaeInputFieldTemplate({ schemaObject, baseField });
}
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -52,6 +52,18 @@ export const buildInputFieldValue = (
fieldValue.value = undefined;
}
if (template.type === 'unet') {
fieldValue.value = undefined;
}
if (template.type === 'clip') {
fieldValue.value = undefined;
}
if (template.type === 'vae') {
fieldValue.value = undefined;
}
if (template.type === 'model') {
fieldValue.value = undefined;
}