diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 53df14330f..b3ee70b359 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -357,6 +357,7 @@ class EventServiceBase: bytes: int, total_bytes: int, parts: List[Dict[str, Union[str, int]]], + id: int ) -> None: """ Emit at intervals while the install job is in progress (remote models only). @@ -376,6 +377,7 @@ class EventServiceBase: "bytes": bytes, "total_bytes": total_bytes, "parts": parts, + "id": id }, ) @@ -390,7 +392,7 @@ class EventServiceBase: payload={"source": source}, ) - def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None: + def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: """ Emit when an install job is completed successfully. @@ -404,6 +406,7 @@ class EventServiceBase: "source": source, "total_bytes": total_bytes, "key": key, + "id": id }, ) @@ -423,6 +426,7 @@ class EventServiceBase: source: str, error_type: str, error: str, + id: int ) -> None: """ Emit when an install job encounters an exception. @@ -437,6 +441,7 @@ class EventServiceBase: "source": source, "error_type": error_type, "error": error, + "id": id }, ) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 2419fbe5da..ec330913c2 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -822,6 +822,7 @@ class ModelInstallService(ModelInstallServiceBase): parts=parts, bytes=job.bytes, total_bytes=job.total_bytes, + id=job.id ) def _signal_job_completed(self, job: ModelInstallJob) -> None: @@ -834,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase): assert job.local_path is not None assert job.config_out is not None key = job.config_out.key - self._event_bus.emit_model_install_completed(str(job.source), key) + self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}") @@ -843,7 +844,7 @@ class ModelInstallService(ModelInstallServiceBase): error = job.error assert error_type is not None assert error is not None - self._event_bus.emit_model_install_error(str(job.source), error_type, error) + self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id) def _signal_job_cancelled(self, job: ModelInstallJob) -> None: self._logger.info(f"{job.source}: model installation was cancelled") diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 406df71e06..4d40f7105c 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -797,6 +797,8 @@ "pathToCustomConfig": "Path To Custom Config", "pickModelType": "Pick Model Type", "predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)", + "prune": "Prune", + "pruneTooltip": "Prune finished imports from queue", "quickAdd": "Quick Add", "removeFromQueue": "Remove From Queue", "repo_id": "Repo ID", @@ -1354,6 +1356,8 @@ "modelAdded": "Model Added: {{modelName}}", "modelAddedSimple": "Model Added", "modelAddFailed": "Model Add Failed", + "modelImportCanceled": "Model Import Canceled", + "modelImportRemoved": "Model Import Removed", "nodesBrokenConnections": "Cannot load. Some connections are broken.", "nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.", "nodesLoaded": "Nodes Loaded", @@ -1387,6 +1391,7 @@ "promptNotSet": "Prompt Not Set", "promptNotSetDesc": "Could not find prompt for this image.", "promptSet": "Prompt Set", + "prunedQueue": "Pruned Queue", "resetInitialImage": "Reset Initial Image", "seedNotSet": "Seed Not Set", "seedNotSetDesc": "Could not find seed for this image.", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 26d7949e9c..6c5abf41a5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -57,8 +57,8 @@ import { addInvocationCompleteEventListener as addInvocationCompleteListener } f import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError'; import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted'; -import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad'; import { addModelInstallEventListener } from './listeners/socketio/socketModelInstall'; +import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad'; import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged'; import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts index b241df27c8..a4cf8127cb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts @@ -1,63 +1,65 @@ -import { logger } from 'app/logging/logger'; +import { modelsApi } from 'services/api/endpoints/models'; import { socketModelInstallCompleted, socketModelInstallDownloading, socketModelInstallError, } from 'services/events/actions'; -import { modelsApi } from 'services/api/endpoints/models'; -import type { components, paths } from 'services/api/schema'; - import { startAppListening } from '../..'; -import { createEntityAdapter } from '@reduxjs/toolkit'; - -const log = logger('socketio'); export const addModelInstallEventListener = () => { startAppListening({ actionCreator: socketModelInstallDownloading, effect: async (action, { dispatch }) => { - const { bytes, local_path, source, timestamp, total_bytes } = action.payload.data; - let message = `Model install started: ${bytes}/${total_bytes}/${source}`; - // below doesnt work, still not sure how to update the importModels data - // dispatch( - // modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { - // importModelsAdapter.updateOne(draft, { - // id: source, - // changes: { - // bytes, - // total_bytes, - // },\q - // }); - // } - // ) - // ); + const { bytes, id } = action.payload.data; - log.debug(action.payload, message); + dispatch( + modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { + const models = JSON.parse(JSON.stringify(draft)) + + const modelIndex = models.findIndex((m) => m.id === id); + + models[modelIndex].bytes = bytes; + models[modelIndex].status = 'downloading'; + return models; + }) + ); }, }); startAppListening({ actionCreator: socketModelInstallCompleted, - effect: (action) => { - const { key, source, timestamp } = action.payload.data; + effect: (action, { dispatch }) => { + const { id } = action.payload.data; - let message = `Model install completed: ${source}`; + dispatch( + modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { + const models = JSON.parse(JSON.stringify(draft)) - // dispatch something that marks the model installation as completed + const modelIndex = models.findIndex((m) => m.id === id); - log.debug(action.payload, message); + models[modelIndex].status = 'completed'; + return models; + }) + ); }, }); startAppListening({ actionCreator: socketModelInstallError, - effect: (action) => { - const { error, error_type, source } = action.payload.data; + effect: (action, { dispatch }) => { + const { id } = action.payload.data; - // dispatch something that marks the model installation as errored + dispatch( + modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { + const models = JSON.parse(JSON.stringify(draft)) - log.debug(action.payload, error); + const modelIndex = models.findIndex((m) => m.id === id); + + models[modelIndex].status = 'error'; + return models; + }) + ); }, }); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx index 14ed0c6848..e2cdf24c2f 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx @@ -1,11 +1,12 @@ -import { Button, Box, Flex, FormControl, FormLabel, Heading, Input, Text, Divider } from '@invoke-ai/ui-library'; -import { t } from 'i18next'; -import { CSSProperties } from 'react'; -import { useImportMainModelsMutation } from '../../../services/api/endpoints/models'; +import { Box, Button, Divider,Flex, FormControl, FormLabel, Heading, Input } from '@invoke-ai/ui-library'; import { useForm } from '@mantine/form'; import { useAppDispatch } from 'app/store/storeHooks'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; +import { t } from 'i18next'; +import type { CSSProperties } from 'react'; +import { useImportMainModelsMutation } from 'services/api/endpoints/models'; + import { ImportQueue } from './ImportQueue'; const formStyles: CSSProperties = { @@ -27,8 +28,6 @@ export const ImportModels = () => { }, }); - console.log('addModelForm', addModelForm.values.location) - const handleAddModelSubmit = (values: ExtendedImportModelConfig) => { importMainModel({ source: values.location, config: undefined }) .unwrap() @@ -77,7 +76,6 @@ export const ImportModels = () => { </Flex> </form> <Divider mt="5" mb="3" /> - <Text>{t('modelManager.importQueue')}</Text> <ImportQueue /> </Box> </Box> diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx index c2dbece4b3..53c037e74f 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx @@ -1,85 +1,72 @@ -import { - Button, - Box, - Flex, - FormControl, - FormLabel, - Heading, - IconButton, - Input, - InputGroup, - InputRightElement, - Progress, - Text, -} from '@invoke-ai/ui-library'; -import { t } from 'i18next'; -import { useMemo } from 'react'; -import { useGetModelImportsQuery } from '../../../services/api/endpoints/models'; +import { Box, Button,Flex, Text } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; -import { PiXBold } from 'react-icons/pi'; +import { t } from 'i18next'; +import { useCallback, useMemo } from 'react'; +import { RiSparklingFill } from 'react-icons/ri'; +import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models'; + +import { ImportQueueModel } from './ImportQueueModel'; export const ImportQueue = () => { const dispatch = useAppDispatch(); - // start with this data then pull from sockets (idk how to do that yet, also might not even use this and just use socket) const { data } = useGetModelImportsQuery(); - const progressValues = useMemo(() => { - if (!data) { - return []; - } - const values = []; - for (let i = 0; i < data.length; i++) { - let value; - if (data[i] && data[i]?.bytes && data[i]?.total_bytes) { - value = (data[i]?.bytes / data[i]?.total_bytes) * 100; - } - values.push(value || undefined); - } - return values; + const [pruneModelImports] = usePruneModelImportsMutation(); + + const pruneQueue = useCallback(() => { + pruneModelImports() + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('toast.prunedQueue'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + }, [pruneModelImports, dispatch]); + + const pruneAvailable = useMemo(() => { + return data?.some( + (model) => model.status === 'cancelled' || model.status === 'error' || model.status === 'completed' + ); }, [data]); return ( - <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full"> - <Flex direction="column" gap="2"> - {data?.map((model, i) => ( - <Flex key={i} gap="3" w="full" alignItems="center" textAlign="center"> - <Text w="20%" whiteSpace="nowrap" overflow="hidden" text-overflow="ellipsis"> - {model.source.repo_id} - </Text> - <Progress - value={progressValues[i]} - isIndeterminate={progressValues[i] === undefined} - aria-label={t('accessibility.invokeProgressBar')} - h={2} - w="50%" - /> - <Text w="20%">{model.status}</Text> - {model.status === 'completed' ? ( - <IconButton - isRound={true} - size="xs" - tooltip={t('modelManager.removeFromQueue')} - aria-label={t('modelManager.removeFromQueue')} - icon={<PiXBold />} - // onClick={handleRemove} - /> - ) : ( - <IconButton - isRound={true} - size="xs" - tooltip={t('modelManager.cancel')} - aria-label={t('modelManager.cancel')} - icon={<PiXBold />} - // onClick={handleCancel} - colorScheme="error" - /> - )} - </Flex> - ))} + <> + <Flex justifyContent="space-between"> + <Text>{t('modelManager.importQueue')}</Text> + <Button + isDisabled={!pruneAvailable} + onClick={pruneQueue} + tooltip={t('modelManager.pruneTooltip')} + rightIcon={<RiSparklingFill />} + > + {t('modelManager.prune')} + </Button> </Flex> - </Box> + <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full"> + <Flex direction="column" gap="2"> + {data?.map((model) => <ImportQueueModel key={model.id} model={model} />)} + </Flex> + </Box> + </> ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueueModel.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueueModel.tsx new file mode 100644 index 0000000000..8704e847c6 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueueModel.tsx @@ -0,0 +1,82 @@ +import { + Flex, + IconButton, + Progress, + Text, +} from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { t } from 'i18next'; +import { useCallback, useMemo } from 'react'; +import { PiXBold } from 'react-icons/pi'; +import { useDeleteModelImportMutation } from 'services/api/endpoints/models'; +import type { ImportModelConfig } from 'services/api/types'; + +type ModelListItemProps = { + model: ImportModelConfig; +}; + +export const ImportQueueModel = (props: ModelListItemProps) => { + const { model } = props; + const dispatch = useAppDispatch(); + + const [deleteImportModel] = useDeleteModelImportMutation(); + + const handleDeleteModelImport = useCallback(() => { + deleteImportModel({ key: model.id }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('toast.modelImportCanceled'), + status: 'success', + }) + ) + ); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + }, [deleteImportModel, model, dispatch]); + + const progressValue = useMemo(() => { + return (model.bytes / model.total_bytes) * 100; + }, [model.bytes, model.total_bytes]); + + return ( + <Flex gap="3" w="full" alignItems="center" textAlign="center"> + <Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis"> + {model.source.repo_id} + </Text> + <Progress + value={progressValue} + isIndeterminate={progressValue === undefined} + aria-label={t('accessibility.invokeProgressBar')} + h={2} + w="50%" + /> + <Text w="20%">{model.status}</Text> + {(model.status === 'downloading' || model.status === 'waiting') && ( + <IconButton + isRound={true} + size="xs" + tooltip={t('modelManager.cancel')} + aria-label={t('modelManager.cancel')} + icon={<PiXBold />} + onClick={handleDeleteModelImport} + /> + )} + </Flex> + ); +}; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 48ee2a7f08..3e1abc8560 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -26,11 +26,9 @@ type UpdateModelArg = { type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; - -type GetModelResponse = - paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; type GetModelMetadataResponse = paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json']; +type GetModelResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>; @@ -68,13 +66,19 @@ type ImportMainModelResponse = type ListImportModelsResponse = paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json']; +type DeleteImportModelsResponse = + paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json']; + +type PruneModelImportsResponse = + paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json']; + type AddMainModelArg = { body: MainModelConfig; }; type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json']; -type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content']['application/json']; +type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content'] export type SearchFolderResponse = paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json']; @@ -308,7 +312,25 @@ export const modelsApi = api.injectEndpoints({ url: buildModelsUrl(`import`), }; }, - providesTags: ['ModelImports'] + providesTags: ['ModelImports'], + }), + deleteModelImport: build.mutation<DeleteImportModelsResponse, DeleteMainModelArg>({ + query: ({ key }) => { + return { + url: buildModelsUrl(`import/${key}`), + method: 'DELETE', + }; + }, + invalidatesTags: ['ModelImports'], + }), + pruneModelImports: build.mutation<PruneModelImportsResponse, void>({ + query: () => { + return { + url: buildModelsUrl('import'), + method: 'PATCH', + }; + }, + invalidatesTags: ['ModelImports'], }), getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({ query: () => { @@ -339,6 +361,8 @@ export const { useGetModelsInFolderQuery, useGetCheckpointConfigsQuery, useGetModelImportsQuery, + useGetModelMetadataQuery, + useDeleteModelImportMutation, + usePruneModelImportsMutation, useGetModelQuery, - useGetModelMetadataQuery } = modelsApi; diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index c306b3d2f6..31f9e4d11d 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -9,11 +9,11 @@ import type { InvocationErrorEvent, InvocationRetrievalErrorEvent, InvocationStartedEvent, + ModelInstallCompletedEvent, + ModelInstallDownloadingEvent, + ModelInstallErrorEvent, ModelLoadCompletedEvent, ModelLoadStartedEvent, - ModelInstallDownloadingEvent, - ModelInstallCompletedEvent, - ModelInstallErrorEvent, QueueItemStatusChangedEvent, SessionRetrievalErrorEvent, } from 'services/events/types'; diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 2cf73a8e69..e0c50f4def 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -51,12 +51,14 @@ export type ModelInstallDownloadingEvent = { source: string; timestamp: number; total_bytes: string; + id: number; }; export type ModelInstallCompletedEvent = { key: number; source: string; timestamp: number; + id: number; }; export type ModelInstallErrorEvent = { @@ -64,6 +66,7 @@ export type ModelInstallErrorEvent = { error_type: string; source: string; timestamp: number; + id: number; }; /** diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 314f9a985f..38f917e394 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -16,11 +16,11 @@ import { socketInvocationError, socketInvocationRetrievalError, socketInvocationStarted, + socketModelInstallCompleted, + socketModelInstallDownloading, + socketModelInstallError, socketModelLoadCompleted, socketModelLoadStarted, - socketModelInstallDownloading, - socketModelInstallCompleted, - socketModelInstallError, socketQueueItemStatusChanged, socketSessionRetrievalError, } from 'services/events/actions';