This commit is contained in:
Mary Hipp 2024-02-23 16:12:49 -05:00 committed by psychedelicious
parent 07fb5d5c19
commit 974658107d
18 changed files with 157 additions and 155 deletions

View File

@ -1,3 +1,4 @@
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { import {
socketModelInstallCompleted, socketModelInstallCompleted,
@ -6,7 +7,6 @@ import {
} from 'services/events/actions'; } from 'services/events/actions';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { api } from '../../../../../../services/api';
export const addModelInstallEventListener = () => { export const addModelInstallEventListener = () => {
startAppListening({ startAppListening({
@ -41,7 +41,7 @@ export const addModelInstallEventListener = () => {
return draft; return draft;
}) })
); );
dispatch(api.util.invalidateTags([{ type: "ModelConfig" }])) dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
}, },
}); });
@ -55,7 +55,7 @@ export const addModelInstallEventListener = () => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.status = 'error'; modelImport.status = 'error';
modelImport.error_reason = error_type modelImport.error_reason = error_type;
} }
return draft; return draft;
}) })

View File

@ -91,7 +91,7 @@ VAEMetadataItem.displayName = 'VAEMetadataItem';
type ModelMetadataItemProps = { type ModelMetadataItemProps = {
label: string; label: string;
modelKey?: string; modelKey?: string;
extra?: string; extra?: string;
onClick: () => void; onClick: () => void;
}; };

View File

@ -1,54 +1,62 @@
import { useCallback } from "react"; import type { EntityState } from '@reduxjs/toolkit';
import { ALL_BASE_MODELS } from "../../../services/api/constants"; import { forEach } from 'lodash-es';
import { useGetMainModelsQuery, useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetControlNetModelsQuery, useGetT2IAdapterModelsQuery, useGetIPAdapterModelsQuery, useGetVaeModelsQuery, } from "../../../services/api/endpoints/models"; import { useCallback } from 'react';
import { EntityState } from "@reduxjs/toolkit"; import { ALL_BASE_MODELS } from 'services/api/constants';
import { forEach } from "lodash-es"; import {
import { AnyModelConfig } from "../../../services/api/types"; useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const useIsImported = () => { export const useIsImported = () => {
const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS); const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
const { data: loras } = useGetLoRAModelsQuery(); const { data: loras } = useGetLoRAModelsQuery();
const { data: embeddings } = useGetTextualInversionModelsQuery(); const { data: embeddings } = useGetTextualInversionModelsQuery();
const { data: controlnets } = useGetControlNetModelsQuery(); const { data: controlnets } = useGetControlNetModelsQuery();
const { data: ipAdapters } = useGetIPAdapterModelsQuery(); const { data: ipAdapters } = useGetIPAdapterModelsQuery();
const { data: t2is } = useGetT2IAdapterModelsQuery(); const { data: t2is } = useGetT2IAdapterModelsQuery();
const { data: vaes } = useGetVaeModelsQuery(); const { data: vaes } = useGetVaeModelsQuery();
const isImported = useCallback(({ name }: { name: string }) => { const isImported = useCallback(
const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes] ({ name }: { name: string }) => {
let isMatch = false; const data = [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes];
for (let index = 0; index < data.length; index++) { let isMatch = false;
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index]; for (let index = 0; index < data.length; index++) {
const modelType: EntityState<AnyModelConfig, string> | undefined = data[index];
const match = modelsFilter(modelType, name) const match = modelsFilter(modelType, name);
if (!!match.length) { if (match.length) {
isMatch = true isMatch = true;
break; break;
}
} }
return isMatch }
}, [mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]) return isMatch;
},
[mainModels, loras, embeddings, controlnets, ipAdapters, t2is, vaes]
);
return { isImported } return { isImported };
} };
const modelsFilter = <T extends AnyModelConfig>( const modelsFilter = <T extends AnyModelConfig>(data: EntityState<T, string> | undefined, nameFilter: string): T[] => {
data: EntityState<T, string> | undefined, const filteredModels: T[] = [];
nameFilter: string,
): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => { forEach(data?.entities, (model) => {
if (!model) { if (!model) {
return; return;
} }
const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase()); const matchesFilter = model.path.toLowerCase().includes(nameFilter.toLowerCase());
if (matchesFilter) { if (matchesFilter) {
filteredModels.push(model); filteredModels.push(model);
} }
}); });
return filteredModels; return filteredModels;
}; };

View File

@ -2,59 +2,59 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store'; import type { PersistConfig, RootState } from 'app/store/store';
type ModelManagerState = { type ModelManagerState = {
_version: 1; _version: 1;
selectedModelKey: string | null; selectedModelKey: string | null;
selectedModelMode: "edit" | "view", selectedModelMode: 'edit' | 'view';
searchTerm: string; searchTerm: string;
filteredModelType: string | null; filteredModelType: string | null;
}; };
export const initialModelManagerState: ModelManagerState = { export const initialModelManagerState: ModelManagerState = {
_version: 1, _version: 1,
selectedModelKey: null, selectedModelKey: null,
selectedModelMode: "view", selectedModelMode: 'view',
filteredModelType: null, filteredModelType: null,
searchTerm: "" searchTerm: '',
}; };
export const modelManagerV2Slice = createSlice({ export const modelManagerV2Slice = createSlice({
name: 'modelmanagerV2', name: 'modelmanagerV2',
initialState: initialModelManagerState, initialState: initialModelManagerState,
reducers: { reducers: {
setSelectedModelKey: (state, action: PayloadAction<string | null>) => { setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
state.selectedModelMode = "view" state.selectedModelMode = 'view';
state.selectedModelKey = action.payload; state.selectedModelKey = action.payload;
},
setSelectedModelMode: (state, action: PayloadAction<"view" | "edit">) => {
state.selectedModelMode = action.payload;
},
setSearchTerm: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
state.filteredModelType = action.payload;
},
}, },
setSelectedModelMode: (state, action: PayloadAction<'view' | 'edit'>) => {
state.selectedModelMode = action.payload;
},
setSearchTerm: (state, action: PayloadAction<string>) => {
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
state.filteredModelType = action.payload;
},
},
}); });
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } = modelManagerV2Slice.actions; export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } =
modelManagerV2Slice.actions;
export const selectModelManagerSlice = (state: RootState) => state.modelmanager; export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ /* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateModelManagerState = (state: any): any => { export const migrateModelManagerState = (state: any): any => {
if (!('_version' in state)) { if (!('_version' in state)) {
state._version = 1; state._version = 1;
} }
return state; return state;
}; };
export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = { export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerV2Slice.name, name: modelManagerV2Slice.name,
initialState: initialModelManagerState, initialState: initialModelManagerState,
migrate: migrateModelManagerState, migrate: migrateModelManagerState,
persistDenylist: [], persistDenylist: [],
}; };

View File

@ -27,9 +27,9 @@ const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>)
); );
return ( return (
<FormControl> <FormControl>
<Flex direction="column" width="full"> <Flex direction="column" width="full">
<FormLabel>{t('modelManager.baseModel')}</FormLabel> <FormLabel>{t('modelManager.baseModel')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} /> <Combobox value={value} options={options} onChange={onChange} />
</Flex> </Flex>
</FormControl> </FormControl>
); );

View File

@ -1,7 +1,7 @@
import { Badge, Tooltip } from '@invoke-ai/ui-library'; import { Badge, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ModelInstallStatus } from '../../../../../services/api/types'; import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = { const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' }, waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },

View File

@ -1,4 +1,4 @@
import { Box, Flex, IconButton, Progress, Tag, Text, Tooltip } from '@invoke-ai/ui-library'; import { Box, Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
@ -6,7 +6,8 @@ import { t } from 'i18next';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models'; import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ModelInstallJob, HFModelSource, LocalModelSource, URLModelSource } from 'services/api/types'; import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge'; import ImportQueueBadge from './ImportQueueBadge';
type ModelListItemProps = { type ModelListItemProps = {

View File

@ -1,12 +1,12 @@
import { Flex, Text, Box, Button, IconButton, Tooltip, Badge } from '@invoke-ai/ui-library'; import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useIsImported } from 'features/modelManagerV2/hooks/useIsImported';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5'; import { IoAdd } from 'react-icons/io5';
import { useAppDispatch } from '../../../../../app/store/storeHooks'; import { useImportMainModelsMutation } from 'services/api/endpoints/models';
import { useImportMainModelsMutation } from '../../../../../services/api/endpoints/models';
import { addToast } from '../../../../system/store/systemSlice';
import { makeToast } from '../../../../system/util/makeToast';
import { useIsImported } from '../../../hooks/useIsImported';
import { useMemo } from 'react';
export const ScanModelResultItem = ({ result }: { result: string }) => { export const ScanModelResultItem = ({ result }: { result: string }) => {
const { t } = useTranslation(); const { t } = useTranslation();
@ -14,11 +14,11 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
const { isImported } = useIsImported(); const { isImported } = useIsImported();
const [importMainModel, { isLoading }] = useImportMainModelsMutation(); const [importMainModel] = useImportMainModelsMutation();
const isAlreadyImported = useMemo(() => { const isAlreadyImported = useMemo(() => {
const prettyName = result.split('\\').slice(-1)[0]; const prettyName = result.split('\\').slice(-1)[0];
console.log({ prettyName });
if (prettyName) { if (prettyName) {
return isImported({ name: prettyName }); return isImported({ name: prettyName });
} else { } else {
@ -26,7 +26,7 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
} }
}, [result, isImported]); }, [result, isImported]);
const handleQuickAdd = () => { const handleQuickAdd = useCallback(() => {
importMainModel({ source: result, config: undefined }) importMainModel({ source: result, config: undefined })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
@ -51,10 +51,10 @@ export const ScanModelResultItem = ({ result }: { result: string }) => {
); );
} }
}); });
}; }, [importMainModel, result, dispatch, t]);
return ( return (
<Flex justifyContent={'space-between'}> <Flex justifyContent="space-between">
<Flex fontSize="sm" flexDir="column"> <Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text> <Text fontWeight="semibold">{result.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result}</Text> <Text variant="subtext">{result}</Text>

View File

@ -1,4 +1,3 @@
export const ScanModels = () => { export const ScanModels = () => {
return null; return null;
}; };

View File

@ -14,13 +14,11 @@ export const ScanModelsForm = () => {
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery(); const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const handleSubmitScan = useCallback(async () => { const handleSubmitScan = useCallback(async () => {
try { _scanModels({ scan_path: scanPath }).catch((error) => {
await _scanModels({ scan_path: scanPath }).unwrap();
} catch (error: any) {
if (error) { if (error) {
setErrorMessage(error.data.detail); setErrorMessage(error.data.detail);
} }
} });
}, [_scanModels, scanPath]); }, [_scanModels, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => { const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {

View File

@ -1,9 +1,10 @@
import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElement, Text } from '@invoke-ai/ui-library'; import { Divider, Flex, Heading, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { t } from 'i18next'; import { t } from 'i18next';
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react'; import { useCallback, useMemo, useState } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { ScanModelResultItem } from './ScanModelResultItem'; import { ScanModelResultItem } from './ScanModelResultItem';
export const ScanModelsResults = ({ results }: { results: string[] }) => { export const ScanModelsResults = ({ results }: { results: string[] }) => {
@ -16,12 +17,9 @@ export const ScanModelsResults = ({ results }: { results: string[] }) => {
}); });
}, [results, searchTerm]); }, [results, searchTerm]);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback( const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
(e) => { setSearchTerm(e.target.value);
setSearchTerm(e.target.value); }, []);
},
[results]
);
const clearSearch = useCallback(() => { const clearSearch = useCallback(() => {
setSearchTerm(''); setSearchTerm('');

View File

@ -1,10 +1,10 @@
import { Button,Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library'; import { Button, Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback } from 'react'; import { useCallback } from 'react';
import type { SubmitHandler} from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useImportMainModelsMutation } from 'services/api/endpoints/models'; import { useImportMainModelsMutation } from 'services/api/endpoints/models';

View File

@ -2,9 +2,8 @@ import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@in
import { AdvancedImport } from './AddModelPanel/AdvancedImport'; import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue'; import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModels } from './AddModelPanel/ScanModels/ScanModels';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm'; import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
export const ImportModels = () => { export const ImportModels = () => {
return ( return (

View File

@ -1,8 +1,8 @@
import { Flex, IconButton,Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library'; import { Flex, IconButton, Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { t } from 'i18next'; import { t } from 'i18next';
import type { ChangeEventHandler} from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';

View File

@ -59,7 +59,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
) )
); );
}); });
}, [convertModel, dispatch, model.base, model.name, t]); }, [convertModel, dispatch, model.key, model.name, t]);
return ( return (
<> <>

View File

@ -58,7 +58,6 @@ type DeleteImportModelsResponse =
type PruneModelImportsResponse = type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json']; paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
export type ScanFolderResponse = export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query']; type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
@ -104,25 +103,25 @@ export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined
const buildProvidesTags = const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) => <TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => { (result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model']; const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) { if (result) {
tags.push( tags.push(
...result.ids.map((id) => ({ ...result.ids.map((id) => ({
type: tagType, type: tagType,
id, id,
})) }))
); );
} }
return tags; return tags;
}; };
const buildTransformResponse = const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) => <T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => { (response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models); return adapter.setAll(adapter.getInitialState(), response.models);
}; };
/** /**
* Builds an endpoint URL for the models router * Builds an endpoint URL for the models router

View File

@ -117,8 +117,8 @@ export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is
export type MergeModelConfig = S['Body_merge']; export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model']; export type ImportModelConfig = S['Body_import_model'];
export type ModelInstallJob = S['ModelInstallJob'] export type ModelInstallJob = S['ModelInstallJob'];
export type ModelInstallStatus = S["InstallStatus"] export type ModelInstallStatus = S['InstallStatus'];
export type HFModelSource = S['HFModelSource']; export type HFModelSource = S['HFModelSource'];
export type CivitaiModelSource = S['CivitaiModelSource']; export type CivitaiModelSource = S['CivitaiModelSource'];

View File

@ -146,29 +146,29 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
data, data,
}) })
); );
}) });
/** /**
* Model Install Completed * Model Install Completed
*/ */
socket.on('model_install_completed', (data) => { socket.on('model_install_completed', (data) => {
dispatch( dispatch(
socketModelInstallCompleted({ socketModelInstallCompleted({
data, data,
}) })
); );
}) });
/** /**
* Model Install Error * Model Install Error
*/ */
socket.on('model_install_error', (data) => { socket.on('model_install_error', (data) => {
dispatch( dispatch(
socketModelInstallError({ socketModelInstallError({
data, data,
}) })
); );
}) });
/** /**
* Session retrieval error * Session retrieval error