chore: Change PipelineModels to MainModels

This commit is contained in:
blessedcoolant 2023-06-29 08:13:36 +12:00 committed by psychedelicious
parent 2ad5a4ea46
commit 6c62f41f2e
17 changed files with 2061 additions and 1673 deletions

View File

@ -1,11 +1,13 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy
from typing import List, Literal, Optional, Union
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from pydantic import BaseModel, Field
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
@ -43,19 +45,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on
class PipelineModelField(BaseModel):
"""Pipeline model field"""
class MainModelField(BaseModel):
"""Main model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
type: Literal["main_model_loader"] = "main_model_loader"
model: PipelineModelField = Field(description="The model to load")
model: MainModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation

View File

@ -4,6 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import Lightbox from 'features/lightbox/components/Lightbox';
import SiteHeader from 'features/system/components/SiteHeader';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -15,11 +16,10 @@ import InvokeTabs from 'features/ui/components/InvokeTabs';
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n';
import { ReactNode, memo, useEffect } from 'react';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
const DEFAULT_CONFIG = {};

View File

@ -6,13 +6,13 @@ import {
ModelInputFieldValue,
} from 'features/nodes/types/types';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { FieldComponentProps } from './types';
import { forEach, isString } from 'lodash-es';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import { forEach, isString } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -22,18 +22,18 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: pipelineModels } = useListModelsQuery({
const { data: mainModels } = useListModelsQuery({
model_type: 'main',
});
const data = useMemo(() => {
if (!pipelineModels) {
if (!mainModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
forEach(mainModels.entities, (model, id) => {
if (!model) {
return;
}
@ -46,11 +46,11 @@ const ModelInputFieldComponent = (
});
return data;
}, [pipelineModels]);
}, [mainModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value]
() => mainModels?.entities[field.value ?? mainModels.ids[0]],
[mainModels?.entities, mainModels?.ids, field.value]
);
const handleValueChanged = useCallback(
@ -71,18 +71,18 @@ const ModelInputFieldComponent = (
);
useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) {
if (field.value && mainModels?.ids.includes(field.value)) {
return;
}
const firstModel = pipelineModels?.ids[0];
const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]);
}, [field.value, handleValueChanged, mainModels?.ids]);
return (
<IAIMantineSelect

View File

@ -1,31 +1,25 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageDTO,
ImageResizeInvocation,
ImageToLatentsInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import {
ITERATE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE,
} from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' });
@ -52,7 +46,7 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -81,9 +75,9 @@ export const buildCanvasImageToImageGraph = (
type: 'noise',
id: NOISE,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -110,7 +104,7 @@ export const buildCanvasImageToImageGraph = (
edges: [
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -120,7 +114,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -130,7 +124,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -170,7 +164,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -180,7 +174,7 @@ export const buildCanvasImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {

View File

@ -1,23 +1,23 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageDTO,
InpaintInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import {
INPAINT,
INPAINT_GRAPH,
ITERATE,
PIPELINE_MODEL_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
INPAINT_GRAPH,
INPAINT,
} from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@ -55,7 +55,7 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
const graph: NonNullableGraph = {
id: INPAINT_GRAPH,
@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[RANGE_OF_SIZE]: {
@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {

View File

@ -1,21 +1,17 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import {
ITERATE,
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/**
* Builds the Canvas tab's Text to Image graph.
@ -38,7 +34,7 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -76,9 +72,9 @@ export const buildCanvasTextToImageGraph = (
scheduler,
steps,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -109,7 +105,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -119,7 +115,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -129,7 +125,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -149,7 +145,7 @@ export const buildCanvasTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {

View File

@ -1,28 +1,28 @@
import { log } from 'app/logging/useLogger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageCollectionInvocation,
ImageResizeInvocation,
ImageToLatentsInvocation,
IterateInvocation,
} from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
LATENTS_TO_LATENTS,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS,
RESIZE,
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' });
@ -69,7 +69,7 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state');
}
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
@ -89,9 +89,9 @@ export const buildLinearImageToImageGraph = (
type: 'noise',
id: NOISE,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -118,7 +118,7 @@ export const buildLinearImageToImageGraph = (
edges: [
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -128,7 +128,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -138,7 +138,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -178,7 +178,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
@ -188,7 +188,7 @@ export const buildLinearImageToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {

View File

@ -1,17 +1,17 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import {
LATENTS_TO_IMAGE,
PIPELINE_MODEL_LOADER,
MAIN_MODEL_LOADER,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
export const buildLinearTextToImageGraph = (
state: RootState
@ -27,7 +27,7 @@ export const buildLinearTextToImageGraph = (
height,
} = state.generation;
const model = modelIdToPipelineModelField(modelId);
const model = modelIdToMainModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
@ -65,9 +65,9 @@ export const buildLinearTextToImageGraph = (
scheduler,
steps,
},
[PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader',
id: PIPELINE_MODEL_LOADER,
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
},
[LATENTS_TO_IMAGE]: {
@ -98,7 +98,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -108,7 +108,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
@ -118,7 +118,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
@ -138,7 +138,7 @@ export const buildLinearTextToImageGraph = (
},
{
source: {
node_id: PIPELINE_MODEL_LOADER,
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {

View File

@ -1,10 +1,10 @@
import { Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { v4 as uuidv4 } from 'uuid';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
/**
* We need to do special handling for some fields
@ -27,7 +27,7 @@ export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'model') {
if (field.value) {
return modelIdToPipelineModelField(field.value);
return modelIdToMainModelField(field.value);
}
}

View File

@ -7,7 +7,7 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';

View File

@ -0,0 +1,16 @@
import { BaseModelType, MainModelField } from 'services/api/types';
/**
* Crudely converts a model id to a main model field
* TODO: Make better
*/
export const modelIdToMainModelField = (modelId: string): MainModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: MainModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -1,18 +0,0 @@
import { BaseModelType, PipelineModelField } from 'services/api/types';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -5,9 +5,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { forEach, isString } from 'lodash-es';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { forEach, isString } from 'lodash-es';
import { useListModelsQuery } from 'services/api/endpoints/models';
export const MODEL_TYPE_MAP = {
@ -23,18 +23,18 @@ const ModelSelect = () => {
(state: RootState) => state.generation.model
);
const { data: pipelineModels, isLoading } = useListModelsQuery({
const { data: mainModels, isLoading } = useListModelsQuery({
model_type: 'main',
});
const data = useMemo(() => {
if (!pipelineModels) {
if (!mainModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
forEach(mainModels.entities, (model, id) => {
if (!model) {
return;
}
@ -47,11 +47,11 @@ const ModelSelect = () => {
});
return data;
}, [pipelineModels]);
}, [mainModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[selectedModelId],
[pipelineModels?.entities, selectedModelId]
() => mainModels?.entities[selectedModelId],
[mainModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback(
@ -65,20 +65,18 @@ const ModelSelect = () => {
);
useEffect(() => {
// If the selected model is not in the list of models, select the first one
// Handles first-run setting of models, and the user deleting the previously-selected model
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
if (selectedModelId && mainModels?.ids.includes(selectedModelId)) {
return;
}
const firstModel = pipelineModels?.ids[0];
const firstModel = mainModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleChangeModel(firstModel);
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
}, [handleChangeModel, mainModels?.ids, selectedModelId]);
return isLoading ? (
<IAIMantineSelect

View File

@ -8,8 +8,8 @@ import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() {
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
const { data: mainModels } = useListModelsQuery({
model_type: 'main',
});
const openModel = useAppSelector(
@ -17,20 +17,20 @@ export default function ModelManagerPanel() {
);
const renderModelEditTabs = () => {
if (!openModel || !pipelineModels) return;
if (!openModel || !mainModels) return;
if (pipelineModels['entities'][openModel]['model_format'] === 'diffusers') {
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
retrievedModel={mainModels['entities'][openModel]}
/>
);
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={pipelineModels['entities'][openModel]}
retrievedModel={mainModels['entities'][openModel]}
/>
);
}

View File

@ -36,8 +36,8 @@ function ModelFilterButton({
}
const ModelList = () => {
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
const { data: mainModels } = useListModelsQuery({
model_type: 'main',
});
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
@ -70,9 +70,9 @@ const ModelList = () => {
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
if (!pipelineModels) return;
if (!mainModels) return;
const modelList = pipelineModels.entities;
const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
@ -179,7 +179,7 @@ const ModelList = () => {
)}
</Flex>
);
}, [pipelineModels, searchText, t, isSelectedFilter]);
}, [mainModels, searchText, t, isSelectedFilter]);
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">

File diff suppressed because it is too large Load Diff

View File

@ -33,7 +33,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models
export type ModelType = S<'ModelType'>;
export type BaseModelType = S<'BaseModelType'>;
export type PipelineModelField = S<'PipelineModelField'>;
export type MainModelField = S<'MainModelField'>;
export type ModelsList = S<'ModelsList'>;
// Graphs
@ -57,8 +57,8 @@ export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>;
export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>;
// ControlNet Nodes
export type ControlNetInvocation = N<'ControlNetInvocation'>;