mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: Change PipelineModels to MainModels
This commit is contained in:
parent
2ad5a4ea46
commit
6c62f41f2e
@ -1,11 +1,13 @@
|
||||
from typing import Literal, Optional, Union, List
|
||||
from pydantic import BaseModel, Field
|
||||
import copy
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
@ -43,19 +45,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
||||
#fmt: on
|
||||
|
||||
|
||||
class PipelineModelField(BaseModel):
|
||||
"""Pipeline model field"""
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a pipeline model, outputting its submodels."""
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||
type: Literal["main_model_loader"] = "main_model_loader"
|
||||
|
||||
model: PipelineModelField = Field(description="The model to load")
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
|
@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { PartialAppConfig } from 'app/types/invokeai';
|
||||
import ImageUploader from 'common/components/ImageUploader';
|
||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import SiteHeader from 'features/system/components/SiteHeader';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||
import i18n from 'i18n';
|
||||
import { ReactNode, memo, useEffect } from 'react';
|
||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
|
||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
|
@ -6,13 +6,13 @@ import {
|
||||
ModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||
@ -22,18 +22,18 @@ const ModelInputFieldComponent = (
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
const { data: mainModels } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!pipelineModels) {
|
||||
if (!mainModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(pipelineModels.entities, (model, id) => {
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
@ -46,11 +46,11 @@ const ModelInputFieldComponent = (
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [pipelineModels]);
|
||||
}, [mainModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
||||
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
||||
() => mainModels?.entities[field.value ?? mainModels.ids[0]],
|
||||
[mainModels?.entities, mainModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
@ -71,18 +71,18 @@ const ModelInputFieldComponent = (
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
||||
if (field.value && mainModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = pipelineModels?.ids[0];
|
||||
const firstModel = mainModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstModel);
|
||||
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
||||
}, [field.value, handleValueChanged, mainModels?.ids]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
|
@ -1,31 +1,25 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import {
|
||||
ITERATE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_LATENTS,
|
||||
RESIZE,
|
||||
} from './constants';
|
||||
import { set } from 'lodash-es';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -52,7 +46,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
@ -81,9 +75,9 @@ export const buildCanvasImageToImageGraph = (
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -110,7 +104,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -120,7 +114,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -130,7 +124,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
@ -170,7 +164,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
@ -180,7 +174,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
|
@ -1,23 +1,23 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageDTO,
|
||||
InpaintInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import {
|
||||
INPAINT,
|
||||
INPAINT_GRAPH,
|
||||
ITERATE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
INPAINT_GRAPH,
|
||||
INPAINT,
|
||||
} from './constants';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -55,7 +55,7 @@ export const buildCanvasInpaintGraph = (
|
||||
// We may need to set the inpaint width and height to scale the image
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: INPAINT_GRAPH,
|
||||
@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
|
@ -1,21 +1,17 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
@ -38,7 +34,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
@ -76,9 +72,9 @@ export const buildCanvasTextToImageGraph = (
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -109,7 +105,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -119,7 +115,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -129,7 +125,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -149,7 +145,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
|
@ -1,28 +1,28 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageCollectionInvocation,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
IterateInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import {
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
LATENTS_TO_LATENTS,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_LATENTS,
|
||||
RESIZE,
|
||||
IMAGE_COLLECTION,
|
||||
IMAGE_COLLECTION_ITERATE,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -69,7 +69,7 @@ export const buildLinearImageToImageGraph = (
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
@ -89,9 +89,9 @@ export const buildLinearImageToImageGraph = (
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -118,7 +118,7 @@ export const buildLinearImageToImageGraph = (
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -128,7 +128,7 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -138,7 +138,7 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
@ -178,7 +178,7 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
@ -188,7 +188,7 @@ export const buildLinearImageToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
|
@ -1,17 +1,17 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
PIPELINE_MODEL_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
|
||||
export const buildLinearTextToImageGraph = (
|
||||
state: RootState
|
||||
@ -27,7 +27,7 @@ export const buildLinearTextToImageGraph = (
|
||||
height,
|
||||
} = state.generation;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
const model = modelIdToMainModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
@ -65,9 +65,9 @@ export const buildLinearTextToImageGraph = (
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[PIPELINE_MODEL_LOADER]: {
|
||||
type: 'pipeline_model_loader',
|
||||
id: PIPELINE_MODEL_LOADER,
|
||||
[MAIN_MODEL_LOADER]: {
|
||||
type: 'main_model_loader',
|
||||
id: MAIN_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
@ -98,7 +98,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -108,7 +108,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
@ -118,7 +118,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -138,7 +138,7 @@ export const buildLinearTextToImageGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: PIPELINE_MODEL_LOADER,
|
||||
node_id: MAIN_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { InputFieldValue } from 'features/nodes/types/types';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { modelIdToMainModelField } from '../modelIdToMainModelField';
|
||||
|
||||
/**
|
||||
* We need to do special handling for some fields
|
||||
@ -27,7 +27,7 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
||||
|
||||
if (field.type === 'model') {
|
||||
if (field.value) {
|
||||
return modelIdToPipelineModelField(field.value);
|
||||
return modelIdToMainModelField(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,7 @@ export const NOISE = 'noise';
|
||||
export const RANDOM_INT = 'rand_int';
|
||||
export const RANGE_OF_SIZE = 'range_of_size';
|
||||
export const ITERATE = 'iterate';
|
||||
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
|
||||
export const MAIN_MODEL_LOADER = 'main_model_loader';
|
||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||
export const RESIZE = 'resize_image';
|
||||
|
@ -0,0 +1,16 @@
|
||||
import { BaseModelType, MainModelField } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a main model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToMainModelField = (modelId: string): MainModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: MainModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -1,18 +0,0 @@
|
||||
import { BaseModelType, PipelineModelField } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a pipeline model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToPipelineModelField = (
|
||||
modelId: string
|
||||
): PipelineModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: PipelineModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -5,9 +5,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { useListModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const MODEL_TYPE_MAP = {
|
||||
@ -23,18 +23,18 @@ const ModelSelect = () => {
|
||||
(state: RootState) => state.generation.model
|
||||
);
|
||||
|
||||
const { data: pipelineModels, isLoading } = useListModelsQuery({
|
||||
const { data: mainModels, isLoading } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!pipelineModels) {
|
||||
if (!mainModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(pipelineModels.entities, (model, id) => {
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
@ -47,11 +47,11 @@ const ModelSelect = () => {
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [pipelineModels]);
|
||||
}, [mainModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => pipelineModels?.entities[selectedModelId],
|
||||
[pipelineModels?.entities, selectedModelId]
|
||||
() => mainModels?.entities[selectedModelId],
|
||||
[mainModels?.entities, selectedModelId]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
@ -65,20 +65,18 @@ const ModelSelect = () => {
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// If the selected model is not in the list of models, select the first one
|
||||
// Handles first-run setting of models, and the user deleting the previously-selected model
|
||||
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
|
||||
if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = pipelineModels?.ids[0];
|
||||
const firstModel = mainModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleChangeModel(firstModel);
|
||||
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
|
||||
}, [handleChangeModel, mainModels?.ids, selectedModelId]);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSelect
|
||||
|
@ -8,8 +8,8 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
const { data: mainModels } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const openModel = useAppSelector(
|
||||
@ -17,20 +17,20 @@ export default function ModelManagerPanel() {
|
||||
);
|
||||
|
||||
const renderModelEditTabs = () => {
|
||||
if (!openModel || !pipelineModels) return;
|
||||
if (!openModel || !mainModels) return;
|
||||
|
||||
if (pipelineModels['entities'][openModel]['model_format'] === 'diffusers') {
|
||||
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
|
||||
return (
|
||||
<DiffusersModelEdit
|
||||
modelToEdit={openModel}
|
||||
retrievedModel={pipelineModels['entities'][openModel]}
|
||||
retrievedModel={mainModels['entities'][openModel]}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<CheckpointModelEdit
|
||||
modelToEdit={openModel}
|
||||
retrievedModel={pipelineModels['entities'][openModel]}
|
||||
retrievedModel={mainModels['entities'][openModel]}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -36,8 +36,8 @@ function ModelFilterButton({
|
||||
}
|
||||
|
||||
const ModelList = () => {
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
const { data: mainModels } = useListModelsQuery({
|
||||
model_type: 'main',
|
||||
});
|
||||
|
||||
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
|
||||
@ -70,9 +70,9 @@ const ModelList = () => {
|
||||
const filteredModelListItemsToRender: ReactNode[] = [];
|
||||
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
||||
|
||||
if (!pipelineModels) return;
|
||||
if (!mainModels) return;
|
||||
|
||||
const modelList = pipelineModels.entities;
|
||||
const modelList = mainModels.entities;
|
||||
|
||||
Object.keys(modelList).forEach((model, i) => {
|
||||
if (
|
||||
@ -179,7 +179,7 @@ const ModelList = () => {
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}, [pipelineModels, searchText, t, isSelectedFilter]);
|
||||
}, [mainModels, searchText, t, isSelectedFilter]);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||
|
3428
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
3428
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because it is too large
Load Diff
@ -33,7 +33,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
|
||||
// Models
|
||||
export type ModelType = S<'ModelType'>;
|
||||
export type BaseModelType = S<'BaseModelType'>;
|
||||
export type PipelineModelField = S<'PipelineModelField'>;
|
||||
export type MainModelField = S<'MainModelField'>;
|
||||
export type ModelsList = S<'ModelsList'>;
|
||||
|
||||
// Graphs
|
||||
@ -57,8 +57,8 @@ export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
|
||||
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
|
||||
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
|
||||
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
|
||||
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
|
||||
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
|
||||
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = N<'ControlNetInvocation'>;
|
||||
|
Loading…
Reference in New Issue
Block a user