mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 23:58:14 +00:00
WIP - model selection for LLaVA
This commit is contained in:
@ -59,6 +59,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
SigLipModel = "SigLipModelField"
|
||||
FluxReduxModel = "FluxReduxModelField"
|
||||
LlavaOnevisionModel = "LLaVAModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@ -205,6 +206,7 @@ class FieldDescriptions:
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
|
||||
flux_redux_conditioning = "FLUX Redux conditioning tensor"
|
||||
vllm_model = "The VLLM model to use"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
|
@ -6,6 +6,8 @@ from invokeai.app.invocations.fields import ImageField, InputField, UIComponent
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, UIType
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -20,11 +22,11 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
|
||||
description="Input text prompt.",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
# vllm_model: ModelIdentifierField = InputField(
|
||||
# title="Image-to-Image Model",
|
||||
# description=FieldDescriptions.vllm_model,
|
||||
# ui_type=UIType.LlavaOnevisionModel,
|
||||
# )
|
||||
vllm_model: ModelIdentifierField = InputField(
|
||||
title="Image-to-Image Model",
|
||||
description=FieldDescriptions.vllm_model,
|
||||
ui_type=UIType.LlavaOnevisionModel,
|
||||
)
|
||||
|
||||
def _get_images(self, context: InvocationContext) -> list[Image]:
|
||||
if self.images is None:
|
||||
|
@ -32,6 +32,8 @@ import {
|
||||
isColorFieldInputTemplate,
|
||||
isControlLoRAModelFieldInputInstance,
|
||||
isControlLoRAModelFieldInputTemplate,
|
||||
isLLaVAModelFieldInputInstance,
|
||||
isLLaVAModelFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
@ -105,6 +107,7 @@ import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInp
|
||||
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
|
||||
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
@ -322,6 +325,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isLLaVAModelFieldInputTemplate(template)) {
|
||||
if (!isLLaVAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputTemplate(template)) {
|
||||
if (!isFluxVAEModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
@ -0,0 +1,55 @@
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LlavaOnevisionConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;
|
||||
|
||||
const LLaVAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLLaVAModels();
|
||||
const _onChange = useCallback(
|
||||
(value: LlavaOnevisionConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldLLaVAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LLaVAModelFieldInputComponent);
|
@ -28,6 +28,7 @@ import type {
|
||||
IntegerGeneratorFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
LLaVAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
ModelIdentifierFieldValue,
|
||||
SchedulerFieldValue,
|
||||
@ -65,6 +66,7 @@ import {
|
||||
zIntegerGeneratorFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zLLaVAModelFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
@ -380,6 +382,9 @@ export const nodesSlice = createSlice({
|
||||
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zLoRAModelFieldValue);
|
||||
},
|
||||
fieldLLaVAModelValueChanged: (state, action: FieldValueAction<LLaVAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zLLaVAModelFieldValue);
|
||||
},
|
||||
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zControlNetModelFieldValue);
|
||||
},
|
||||
@ -509,6 +514,7 @@ export const {
|
||||
fieldSpandrelImageToImageModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldLLaVAModelValueChanged,
|
||||
fieldModelIdentifierValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldIntegerValueChanged,
|
||||
@ -633,6 +639,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldLLaVAModelValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldIntegerValueChanged,
|
||||
fieldIntegerCollectionValueChanged,
|
||||
|
@ -69,6 +69,7 @@ const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
"llava_onevision",
|
||||
'control_lora',
|
||||
'controlnet',
|
||||
't2i_adapter',
|
||||
|
@ -189,6 +189,10 @@ const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zLLaVAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LLaVAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@ -273,6 +277,7 @@ const zStatefulFieldType = z.union([
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zLLaVAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
@ -309,6 +314,7 @@ const modelFieldTypeNames = [
|
||||
zSDXLRefinerModelFieldType.shape.name.value,
|
||||
zVAEModelFieldType.shape.name.value,
|
||||
zLoRAModelFieldType.shape.name.value,
|
||||
zLLaVAModelFieldType.shape.name.value,
|
||||
zControlNetModelFieldType.shape.name.value,
|
||||
zIPAdapterModelFieldType.shape.name.value,
|
||||
zT2IAdapterModelFieldType.shape.name.value,
|
||||
@ -891,6 +897,27 @@ export const isLoRAModelFieldInputInstance = buildInstanceTypeGuard(zLoRAModelFi
|
||||
export const isLoRAModelFieldInputTemplate = buildTemplateTypeGuard<LoRAModelFieldInputTemplate>('LoRAModelField');
|
||||
// #endregion
|
||||
|
||||
// #region LLaVAModelField
|
||||
export const zLLaVAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zLLaVAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zLLaVAModelFieldValue,
|
||||
});
|
||||
const zLLaVAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLLaVAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zLLaVAModelFieldValue,
|
||||
});
|
||||
const zLLaVAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLLaVAModelFieldType,
|
||||
});
|
||||
export type LLaVAModelFieldValue = z.infer<typeof zLLaVAModelFieldValue>;
|
||||
export type LLaVAModelFieldInputInstance = z.infer<typeof zLLaVAModelFieldInputInstance>;
|
||||
export type LLaVAModelFieldInputTemplate = z.infer<typeof zLLaVAModelFieldInputTemplate>;
|
||||
export const isLLaVAModelFieldInputInstance = buildInstanceTypeGuard(zLLaVAModelFieldInputInstance);
|
||||
export const isLLaVAModelFieldInputTemplate = buildTemplateTypeGuard<LLaVAModelFieldInputTemplate>('LLaVAModelField');
|
||||
// #endregion
|
||||
|
||||
|
||||
// #region ControlNetModelField
|
||||
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
|
||||
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@ -1739,6 +1766,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zLLaVAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
@ -1785,6 +1813,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
zLLaVAModelFieldInputInstance,
|
||||
zControlNetModelFieldInputInstance,
|
||||
zIPAdapterModelFieldInputInstance,
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
@ -1825,6 +1854,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
zLoRAModelFieldInputTemplate,
|
||||
zLLaVAModelFieldInputTemplate,
|
||||
zControlNetModelFieldInputTemplate,
|
||||
zIPAdapterModelFieldInputTemplate,
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
@ -1871,6 +1901,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
zLoRAModelFieldOutputTemplate,
|
||||
zLLaVAModelFieldOutputTemplate,
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
zIPAdapterModelFieldOutputTemplate,
|
||||
zT2IAdapterModelFieldOutputTemplate,
|
||||
|
@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
IntegerField: 0,
|
||||
IPAdapterModelField: undefined,
|
||||
LoRAModelField: undefined,
|
||||
LLaVAModelField: undefined,
|
||||
ModelIdentifierField: undefined,
|
||||
MainModelField: undefined,
|
||||
SchedulerField: 'dpmpp_3m_k',
|
||||
|
@ -24,6 +24,7 @@ import type {
|
||||
IntegerFieldInputTemplate,
|
||||
IntegerGeneratorFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
LLaVAModelFieldInputTemplate,
|
||||
LoRAModelFieldInputTemplate,
|
||||
MainModelFieldInputTemplate,
|
||||
ModelIdentifierFieldInputTemplate,
|
||||
@ -448,6 +449,19 @@ const buildControlLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<Control
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLLaVAModelFieldInputTemplate: FieldInputTemplateBuilder<LLaVAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: LLaVAModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -741,6 +755,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
IntegerField: buildIntegerFieldInputTemplate,
|
||||
IPAdapterModelField: buildIPAdapterModelFieldInputTemplate,
|
||||
LoRAModelField: buildLoRAModelFieldInputTemplate,
|
||||
LLaVAModelField: buildLLaVAModelFieldInputTemplate,
|
||||
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
|
||||
MainModelField: buildMainModelFieldInputTemplate,
|
||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||
|
@ -12977,6 +12977,12 @@ export type components = {
|
||||
* @default
|
||||
*/
|
||||
prompt?: string;
|
||||
/**
|
||||
* Image-to-Image Model
|
||||
* @description The VLLM model to use
|
||||
* @default null
|
||||
*/
|
||||
vllm_model?: components["schemas"]["ModelIdentifierField"];
|
||||
/**
|
||||
* type
|
||||
* @default llava_onevision_vllm
|
||||
@ -20814,7 +20820,7 @@ export type components = {
|
||||
* used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
* @enum {string}
|
||||
*/
|
||||
UIType: "MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
|
||||
UIType: "MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
|
||||
/** UNetField */
|
||||
UNetField: {
|
||||
/** @description Info to load unet submodel */
|
||||
|
Reference in New Issue
Block a user