mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add search to Scanned Models
This commit is contained in:
parent
98e6a56714
commit
41e7b008fb
@ -3,10 +3,13 @@ import { makeToast } from 'app/components/Toaster';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { difference, map, values } from 'lodash-es';
|
||||
import { MouseEvent, useCallback } from 'react';
|
||||
import { difference, forEach, map, values } from 'lodash-es';
|
||||
import { ChangeEvent, MouseEvent, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
SearchFolderResponse,
|
||||
useGetMainModelsQuery,
|
||||
useGetModelsInFolderQuery,
|
||||
useImportMainModelsMutation,
|
||||
@ -17,22 +20,35 @@ export default function FoundModelsList() {
|
||||
const searchFolder = useAppSelector(
|
||||
(state: RootState) => state.modelmanager.searchFolder
|
||||
);
|
||||
|
||||
// Get all model paths from a given directory
|
||||
const { data: foundModels } = useGetModelsInFolderQuery({
|
||||
search_path: searchFolder ? searchFolder : '',
|
||||
});
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
|
||||
// Get paths of models that are already installed
|
||||
const { data: installedModels } = useGetMainModelsQuery();
|
||||
const installedModelValues = values(installedModels?.entities);
|
||||
const installedModelPaths = map(installedModelValues, 'path');
|
||||
|
||||
// Only select models those that aren't already installed to Invoke
|
||||
const notInstalledModels = difference(foundModels, installedModelPaths);
|
||||
// Get all model paths from a given directory
|
||||
const { foundModels, notInstalledModels, filteredModels } =
|
||||
useGetModelsInFolderQuery(
|
||||
{
|
||||
search_path: searchFolder ? searchFolder : '',
|
||||
},
|
||||
{
|
||||
selectFromResult: ({ data }) => {
|
||||
const installedModelValues = values(installedModels?.entities);
|
||||
const installedModelPaths = map(installedModelValues, 'path');
|
||||
// Only select models those that aren't already installed to Invoke
|
||||
const notInstalledModels = difference(data, installedModelPaths);
|
||||
return {
|
||||
foundModels: data,
|
||||
notInstalledModels: notInstalledModels,
|
||||
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
|
||||
};
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const quickAddHandler = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
@ -69,6 +85,10 @@ export default function FoundModelsList() {
|
||||
[dispatch, importMainModel]
|
||||
);
|
||||
|
||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNameFilter(e.target.value);
|
||||
}, []);
|
||||
|
||||
const renderFoundModels = () => {
|
||||
if (!searchFolder) return;
|
||||
|
||||
@ -102,6 +122,10 @@ export default function FoundModelsList() {
|
||||
minW: '50%',
|
||||
}}
|
||||
>
|
||||
<IAIInput
|
||||
onChange={handleSearchFilter}
|
||||
label={t('modelManager.search')}
|
||||
/>
|
||||
<Flex p={2} gap={2}>
|
||||
<Text
|
||||
sx={{
|
||||
@ -119,7 +143,7 @@ export default function FoundModelsList() {
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
{notInstalledModels.map((model) => (
|
||||
{filteredModels.map((model) => (
|
||||
<Flex
|
||||
sx={{
|
||||
p: 4,
|
||||
@ -172,3 +196,20 @@ export default function FoundModelsList() {
|
||||
|
||||
return renderFoundModels();
|
||||
}
|
||||
|
||||
const foundModelsFilter = (
|
||||
data: SearchFolderResponse | undefined,
|
||||
nameFilter: string
|
||||
) => {
|
||||
const filteredModels: SearchFolderResponse = [];
|
||||
forEach(data, (model) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (model.includes(nameFilter)) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
});
|
||||
return filteredModels;
|
||||
};
|
||||
|
@ -93,7 +93,7 @@ type AddMainModelArg = {
|
||||
type AddMainModelResponse =
|
||||
paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type SearchFolderResponse =
|
||||
export type SearchFolderResponse =
|
||||
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type CheckpointConfigsResponse =
|
||||
|
Loading…
Reference in New Issue
Block a user