This commit is contained in:
Mary Hipp 2024-02-23 16:12:49 -05:00 committed by psychedelicious
parent 07fb5d5c19
commit 974658107d
18 changed files with 157 additions and 155 deletions

View File

@ -1,3 +1,4 @@
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCompleted,
@ -6,7 +7,6 @@ import {
} from 'services/events/actions';
import { startAppListening } from '../..';
import { api } from '../../../../../../services/api';
export const addModelInstallEventListener = () => {
startAppListening({
@ -41,7 +41,7 @@ export const addModelInstallEventListener = () => {
return draft;
})
);
dispatch(api.util.invalidateTags([{ type: "ModelConfig" }]))
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
},
});
@ -55,7 +55,7 @@ export const addModelInstallEventListener = () => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';
modelImport.error_reason = error_type
modelImport.error_reason = error_type;
}
return draft;
})

View File

@ -91,7 +91,7 @@ VAEMetadataItem.displayName = 'VAEMetadataItem';
type ModelMetadataItemProps = {
label: string;
modelKey?: string;
extra?: string;
onClick: () => void;
};

View File

@ -1,54 +1,62 @@
import { useCallback } from "react";
import { ALL_BASE_MODELS } from "../../../services/api/constants";
import { useGetMainModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetControlNetModelsQuery, useGetT2IAdapterModelsQuery, useGetIPAdapterModelsQuery, useGetVaeModelsQuery, } from "../../../services/api/endpoints/models";
import { EntityState } from "@reduxjs/toolkit";
import { forEach } from "lodash-es";
import { AnyModelConfig } from "../../../services/api/types";
import type { EntityState } from '@reduxjs/toolkit';
import { forEach } from 'lodash-es';
import { useCallback } from 'react';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const useIsImported = () => {
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const { data: loras } = useGetLoRAModelsQuery();
const { data: embeddings } = useGetTextualInversionModelsQuery();
const { data: controlnets } = useGetControlNetModelsQuery();
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
const { data: t2is } = useGetT2IAdapterModelsQuery();
const { data: vaes } = useGetVaeModelsQuery();
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const { data: loras } = useGetLoRAModelsQuery();
const { data: embeddings } = useGetTextualInversionModelsQuery();
const { data: controlnets } = useGetControlNetModelsQuery();
const { data: ipAdapters } = useGetIPAdapterModelsQuery();
const { data: t2is } = useGetT2IAdapterModelsQuery();
const { data: vaes } = useGetVaeModelsQuery();
const isImported = useCallback(({ name }: { name: string }) => {
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
let isMatch = false;
for (let index = 0; index < data.length; index++) {
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
const isImported = useCallback(
({ name }: { name: string }) => {
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes];
let isMatch = false;
for (let index = 0; index < data.length; index++) {
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
const match = modelsFilter(modelType, name)
const match = modelsFilter(modelType, name);
if (!!match.length) {
isMatch = true
break;
}
if (match.length) {
isMatch = true;
break;
}
return isMatch
}, [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes])
}
return isMatch;
},
[mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
);
return { isImported }
}
return { isImported };
};
const modelsFilter = <T extends AnyModelConfig>(
data: EntityState<T, string> | undefined,
nameFilter: string,
): T[] => {
const filteredModels: T[] = [];
const modelsFilter = <T extends AnyModelConfig>(data: EntityState<T, string> | undefined, nameFilter: string): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
if (matchesFilter) {
filteredModels.push(model);
}
});
return filteredModels;
};
if (matchesFilter) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -2,59 +2,59 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
selectedModelMode: "edit" | "view",
searchTerm: string;
filteredModelType: string | null;
_version: 1;
selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view';
searchTerm: string;
filteredModelType: string | null;
};
export const initialModelManagerState: ModelManagerState = {
_version: 1,
selectedModelKey: null,
selectedModelMode: "view",
filteredModelType: null,
searchTerm: ""
_version: 1,
selectedModelKey: null,
selectedModelMode: 'view',
filteredModelType: null,
searchTerm: '',
};
export const modelManagerV2Slice = createSlice({
name: 'modelmanagerV2',
initialState: initialModelManagerState,
reducers: {
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
state.selectedModelMode = "view"
state.selectedModelKey = action.payload;
},
setSelectedModelMode: (state, action: PayloadAction<"view" | "edit">) => {
state.selectedModelMode = action.payload;
},
setSearchTerm: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
state.filteredModelType = action.payload;
},
name: 'modelmanagerV2',
initialState: initialModelManagerState,
reducers: {
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
state.selectedModelMode = 'view';
state.selectedModelKey = action.payload;
},
setSelectedModelMode: (state, action: PayloadAction<'view' | 'edit'>) => {
state.selectedModelMode = action.payload;
},
setSearchTerm: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
state.filteredModelType = action.payload;
},
},
});
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } = modelManagerV2Slice.actions;
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } =
modelManagerV2Slice.actions;
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateModelManagerState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerV2Slice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: [],
name: modelManagerV2Slice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: [],
};

View File

@ -27,9 +27,9 @@ const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>)
);
return (
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} />
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} />
</Flex>
</FormControl>
);

View File

@ -1,7 +1,7 @@
import { Badge, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { ModelInstallStatus } from '../../../../../services/api/types';
import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },

View File

@ -1,4 +1,4 @@
import { Box, Flex, IconButton, Progress, Tag, Text, Tooltip } from '@invoke-ai/ui-library';
import { Box, Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
@ -6,7 +6,8 @@ import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ModelInstallJob, HFModelSource, LocalModelSource, URLModelSource } from 'services/api/types';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge';
type ModelListItemProps = {

View File

@ -1,12 +1,12 @@
import { Flex, Text, Box, Button, IconButton, Tooltip, Badge } from '@invoke-ai/ui-library';
import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useIsImported } from 'features/modelManagerV2/hooks/useIsImported';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import { useAppDispatch } from '../../../../../app/store/storeHooks';
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
import { addToast } from '../../../../system/store/systemSlice';
import { makeToast } from '../../../../system/util/makeToast';
import { useIsImported } from '../../../hooks/useIsImported';
import { useMemo } from 'react';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
export const ScanModelResultItem = ({ result }: { result: string }) => {
const { t } = useTranslation();
@ -14,11 +14,11 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
const { isImported } = useIsImported();
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const [importMainModel] = useImportMainModelsMutation();
const isAlreadyImported = useMemo(() => {
const prettyName = result.split('\\').slice(-1)[0];
console.log({ prettyName });
if (prettyName) {
return isImported({ name: prettyName });
} else {
@ -26,7 +26,7 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
}
}, [result, isImported]);
const handleQuickAdd = () => {
const handleQuickAdd = useCallback(() => {
importMainModel({ source: result, config: undefined })
.unwrap()
.then((_) => {
@ -51,10 +51,10 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
);
}
});
};
}, [importMainModel, result, dispatch, t]);
return (
<Flex justifyContent={'space-between'}>
<Flex justifyContent="space-between">
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result}</Text>

View File

@ -1,4 +1,3 @@
export const ScanModels = () => {
return null;
};

View File

@ -14,13 +14,11 @@ export const ScanModelsForm = () => {
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const handleSubmitScan = useCallback(async () => {
try {
await _scanModels({ scan_path: scanPath }).unwrap();
} catch (error: any) {
_scanModels({ scan_path: scanPath }).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
}
});
}, [_scanModels, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {

View File

@ -1,9 +1,10 @@
import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElement, Text } from '@invoke-ai/ui-library';
import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { t } from 'i18next';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { PiXBold } from 'react-icons/pi';
import { ScanModelResultItem } from './ScanModelResultItem';
export const ScanModelsResults = ({ results }: { results: string[] }) => {
@ -16,12 +17,9 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
});
}, [results, searchTerm]);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback(
(e) => {
setSearchTerm(e.target.value);
},
[results]
);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setSearchTerm(e.target.value);
}, []);
const clearSearch = useCallback(() => {
setSearchTerm('');

View File

@ -1,10 +1,10 @@
import { Button,Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { Button, Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback } from 'react';
import type { SubmitHandler} from 'react-hook-form';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';

View File

@ -2,9 +2,8 @@ import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@in
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModels } from './AddModelPanel/ScanModels/ScanModels';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
export const ImportModels = () => {
return (

View File

@ -1,8 +1,8 @@
import { Flex, IconButton,Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
import { Flex, IconButton, Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { t } from 'i18next';
import type { ChangeEventHandler} from 'react';
import type { ChangeEventHandler } from 'react';
import { useCallback } from 'react';
import { PiXBold } from 'react-icons/pi';

View File

@ -59,7 +59,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
)
);
});
}, [convertModel, dispatch, model.base, model.name, t]);
}, [convertModel, dispatch, model.key, model.name, t]);
return (
<>

View File

@ -58,7 +58,6 @@ type DeleteImportModelsResponse =
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
@ -104,25 +103,25 @@ export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
/**
* Builds an endpoint URL for the models router

View File

@ -117,8 +117,8 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model'];
export type ModelInstallJob = S['ModelInstallJob']
export type ModelInstallStatus = S["InstallStatus"]
export type ModelInstallJob = S['ModelInstallJob'];
export type ModelInstallStatus = S['InstallStatus'];
export type HFModelSource = S['HFModelSource'];
export type CivitaiModelSource = S['CivitaiModelSource'];

View File

@ -146,29 +146,29 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
data,
})
);
})
});
/**
/**
* Model Install Completed
*/
socket.on('model_install_completed', (data) => {
dispatch(
socketModelInstallCompleted({
data,
})
);
})
socket.on('model_install_completed', (data) => {
dispatch(
socketModelInstallCompleted({
data,
})
);
});
/**
* Model Install Error
*/
socket.on('model_install_error', (data) => {
dispatch(
socketModelInstallError({
data,
})
);
})
socket.on('model_install_error', (data) => {
dispatch(
socketModelInstallError({
data,
})
);
});
/**
* Session retrieval error