chore: Change PipelineModels to MainModels

This commit is contained in:
blessedcoolant 2023-06-29 08:13:36 +12:00 committed by psychedelicious
parent 2ad5a4ea46
commit 6c62f41f2e
17 changed files with 2061 additions and 1673 deletions

View File

@ -1,11 +1,13 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy import copy
from typing import List, Literal, Optional, Union
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from pydantic import BaseModel, Field
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel") model_name: str = Field(description="Info to load submodel")
@ -43,19 +45,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on #fmt: on
class PipelineModelField(BaseModel): class MainModelField(BaseModel):
"""Pipeline model field""" """Main model field"""
model_name: str = Field(description="Name of the model") model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model") base_model: BaseModelType = Field(description="Base model")
class PipelineModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader" type: Literal["main_model_loader"] = "main_model_loader"
model: PipelineModelField = Field(description="The model to load") model: MainModelField = Field(description="The model to load")
# TODO: precision? # TODO: precision?
# Schema customisation # Schema customisation

View File

@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import Lightbox from 'features/lightbox/components/Lightbox'; import Lightbox from 'features/lightbox/components/Lightbox';
import SiteHeader from 'features/system/components/SiteHeader'; import SiteHeader from 'features/system/components/SiteHeader';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n'; import i18n from 'i18n';
import { ReactNode, memo, useEffect } from 'react'; import { ReactNode, memo, useEffect } from 'react';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};

View File

@ -6,13 +6,13 @@ import {
ModelInputFieldValue, ModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { FieldComponentProps } from './types';
import { forEach, isString } from 'lodash-es';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useListModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate> props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -22,18 +22,18 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { data: pipelineModels } = useListModelsQuery({ const { data: mainModels } = useListModelsQuery({
model_type: 'main', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {
if (!pipelineModels) { if (!mainModels) {
return []; return [];
} }
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => { forEach(mainModels.entities, (model, id) => {
if (!model) { if (!model) {
return; return;
} }
@ -46,11 +46,11 @@ const ModelInputFieldComponent = (
}); });
return data; return data;
}, [pipelineModels]); }, [mainModels]);
const selectedModel = useMemo( const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], () => mainModels?.entities[field.value ?? mainModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value] [mainModels?.entities, mainModels?.ids, field.value]
); );
const handleValueChanged = useCallback( const handleValueChanged = useCallback(
@ -71,18 +71,18 @@ const ModelInputFieldComponent = (
); );
useEffect(() => { useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) { if (field.value && mainModels?.ids.includes(field.value)) {
return; return;
} }
const firstModel = pipelineModels?.ids[0]; const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) { if (!isString(firstModel)) {
return; return;
} }
handleValueChanged(firstModel); handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]); }, [field.value, handleValueChanged, mainModels?.ids]);
return ( return (
<IAIMantineSelect <IAIMantineSelect

View File

@ -1,31 +1,25 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
ImageDTO, ImageDTO,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { log } from 'app/logging/useLogger'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { import {
ITERATE, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE, RESIZE,
} from './constants'; } from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -52,7 +46,7 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -81,9 +75,9 @@ export const buildCanvasImageToImageGraph = (
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -110,7 +104,7 @@ export const buildCanvasImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -120,7 +114,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -130,7 +124,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -170,7 +164,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -180,7 +174,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {

View File

@ -1,23 +1,23 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
ImageDTO, ImageDTO,
InpaintInvocation, InpaintInvocation,
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { log } from 'app/logging/useLogger';
import { import {
INPAINT,
INPAINT_GRAPH,
ITERATE, ITERATE,
PIPELINE_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
INPAINT_GRAPH,
INPAINT,
} from './constants'; } from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -55,7 +55,7 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,
@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {

View File

@ -1,21 +1,17 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -38,7 +34,7 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -76,9 +72,9 @@ export const buildCanvasTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -109,7 +105,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -119,7 +115,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -129,7 +125,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -149,7 +145,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {

View File

@ -1,28 +1,28 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
ImageCollectionInvocation, ImageCollectionInvocation,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation, ImageToLatentsInvocation,
IterateInvocation, IterateInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { log } from 'app/logging/useLogger'; import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE, RESIZE,
IMAGE_COLLECTION, IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE, IMAGE_COLLECTION_ITERATE,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -69,7 +69,7 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
@ -89,9 +89,9 @@ export const buildLinearImageToImageGraph = (
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -118,7 +118,7 @@ export const buildLinearImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -128,7 +128,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -138,7 +138,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -178,7 +178,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -188,7 +188,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {

View File

@ -1,17 +1,17 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { import {
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER, MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
export const buildLinearTextToImageGraph = ( export const buildLinearTextToImageGraph = (
state: RootState state: RootState
@ -27,7 +27,7 @@ export const buildLinearTextToImageGraph = (
height, height,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToMainModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -65,9 +65,9 @@ export const buildLinearTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[PIPELINE_MODEL_LOADER]: { [MAIN_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'main_model_loader',
id: PIPELINE_MODEL_LOADER, id: MAIN_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
@ -98,7 +98,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -108,7 +108,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -118,7 +118,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -138,7 +138,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: PIPELINE_MODEL_LOADER, node_id: MAIN_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {

View File

@ -1,10 +1,10 @@
import { Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types'; import { InputFieldValue } from 'features/nodes/types/types';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { v4 as uuidv4 } from 'uuid';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
/** /**
* We need to do special handling for some fields * We need to do special handling for some fields
@ -27,7 +27,7 @@ export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'model') { if (field.type === 'model') {
if (field.value) { if (field.value) {
return modelIdToPipelineModelField(field.value); return modelIdToMainModelField(field.value);
} }
} }

View File

@ -7,7 +7,7 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader'; export const MAIN_MODEL_LOADER = 'main_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';

View File

@ -0,0 +1,16 @@
import { BaseModelType, MainModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToMainModelField = (modelId: string): MainModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,18 +0,0 @@
import { BaseModelType, PipelineModelField } from 'services/api/types';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -5,9 +5,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { forEach, isString } from 'lodash-es';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { forEach, isString } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models'; import { useListModelsQuery } from 'services/api/endpoints/models';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
@ -23,18 +23,18 @@ const ModelSelect = () => {
(state: RootState) => state.generation.model (state: RootState) => state.generation.model
); );
const { data: pipelineModels, isLoading } = useListModelsQuery({ const { data: mainModels, isLoading } = useListModelsQuery({
model_type: 'main', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {
if (!pipelineModels) { if (!mainModels) {
return []; return [];
} }
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => { forEach(mainModels.entities, (model, id) => {
if (!model) { if (!model) {
return; return;
} }
@ -47,11 +47,11 @@ const ModelSelect = () => {
}); });
return data; return data;
}, [pipelineModels]); }, [mainModels]);
const selectedModel = useMemo( const selectedModel = useMemo(
() => pipelineModels?.entities[selectedModelId], () => mainModels?.entities[selectedModelId],
[pipelineModels?.entities, selectedModelId] [mainModels?.entities, selectedModelId]
); );
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
@ -65,20 +65,18 @@ const ModelSelect = () => {
); );
useEffect(() => { useEffect(() => {
// If the selected model is not in the list of models, select the first one if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
// Handles first-run setting of models, and the user deleting the previously-selected model
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
return; return;
} }
const firstModel = pipelineModels?.ids[0]; const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) { if (!isString(firstModel)) {
return; return;
} }
handleChangeModel(firstModel); handleChangeModel(firstModel);
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]); }, [handleChangeModel, mainModels?.ids, selectedModelId]);
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSelect

View File

@ -8,8 +8,8 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const { data: pipelineModels } = useListModelsQuery({ const { data: mainModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const openModel = useAppSelector( const openModel = useAppSelector(
@ -17,20 +17,20 @@ export default function ModelManagerPanel() {
); );
const renderModelEditTabs = () => { const renderModelEditTabs = () => {
if (!openModel || !pipelineModels) return; if (!openModel || !mainModels) return;
if (pipelineModels['entities'][openModel]['model_format'] === 'diffusers') { if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
return ( return (
<DiffusersModelEdit <DiffusersModelEdit
modelToEdit={openModel} modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]} retrievedModel={mainModels['entities'][openModel]}
/> />
); );
} else { } else {
return ( return (
<CheckpointModelEdit <CheckpointModelEdit
modelToEdit={openModel} modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]} retrievedModel={mainModels['entities'][openModel]}
/> />
); );
} }

View File

@ -36,8 +36,8 @@ function ModelFilterButton({
} }
const ModelList = () => { const ModelList = () => {
const { data: pipelineModels } = useListModelsQuery({ const { data: mainModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const [renderModelList, setRenderModelList] = React.useState<boolean>(false); const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
@ -70,9 +70,9 @@ const ModelList = () => {
const filteredModelListItemsToRender: ReactNode[] = []; const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = []; const localFilteredModelListItemsToRender: ReactNode[] = [];
if (!pipelineModels) return; if (!mainModels) return;
const modelList = pipelineModels.entities; const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => { Object.keys(modelList).forEach((model, i) => {
if ( if (
@ -179,7 +179,7 @@ const ModelList = () => {
)} )}
</Flex> </Flex>
); );
}, [pipelineModels, searchText, t, isSelectedFilter]); }, [mainModels, searchText, t, isSelectedFilter]);
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">

File diff suppressed because it is too large Load Diff

View File

@ -33,7 +33,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models // Models
export type ModelType = S<'ModelType'>; export type ModelType = S<'ModelType'>;
export type BaseModelType = S<'BaseModelType'>; export type BaseModelType = S<'BaseModelType'>;
export type PipelineModelField = S<'PipelineModelField'>; export type MainModelField = S<'MainModelField'>;
export type ModelsList = S<'ModelsList'>; export type ModelsList = S<'ModelsList'>;
// Graphs // Graphs
@ -57,8 +57,8 @@ export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>; export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>; export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = N<'ControlNetInvocation'>; export type ControlNetInvocation = N<'ControlNetInvocation'>;