diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py
index 53df14330f..b3ee70b359 100644
--- a/invokeai/app/services/events/events_base.py
+++ b/invokeai/app/services/events/events_base.py
@@ -357,6 +357,7 @@ class EventServiceBase:
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
+ id: int
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
@@ -376,6 +377,7 @@ class EventServiceBase:
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
+ "id": id
},
)
@@ -390,7 +392,7 @@ class EventServiceBase:
payload={"source": source},
)
- def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
+ def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
@@ -404,6 +406,7 @@ class EventServiceBase:
"source": source,
"total_bytes": total_bytes,
"key": key,
+ "id": id
},
)
@@ -423,6 +426,7 @@ class EventServiceBase:
source: str,
error_type: str,
error: str,
+ id: int
) -> None:
"""
Emit when an install job encounters an exception.
@@ -437,6 +441,7 @@ class EventServiceBase:
"source": source,
"error_type": error_type,
"error": error,
+ "id": id
},
)
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 2419fbe5da..ec330913c2 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -822,6 +822,7 @@ class ModelInstallService(ModelInstallServiceBase):
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
+ id=job.id
)
def _signal_job_completed(self, job: ModelInstallJob) -> None:
@@ -834,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase):
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
- self._event_bus.emit_model_install_completed(str(job.source), key)
+ self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
@@ -843,7 +844,7 @@ class ModelInstallService(ModelInstallServiceBase):
error = job.error
assert error_type is not None
assert error is not None
- self._event_bus.emit_model_install_error(str(job.source), error_type, error)
+ self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled")
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 4e45d8ab23..7d3387cc3f 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -797,6 +797,8 @@
"pathToCustomConfig": "Path To Custom Config",
"pickModelType": "Pick Model Type",
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
+ "prune": "Prune",
+ "pruneTooltip": "Prune finished imports from queue",
"quickAdd": "Quick Add",
"removeFromQueue": "Remove From Queue",
"repo_id": "Repo ID",
@@ -1354,6 +1356,8 @@
"modelAdded": "Model Added: {{modelName}}",
"modelAddedSimple": "Model Added",
"modelAddFailed": "Model Add Failed",
+ "modelImportCanceled": "Model Import Canceled",
+ "modelImportRemoved": "Model Import Removed",
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
"nodesLoaded": "Nodes Loaded",
@@ -1387,6 +1391,7 @@
"promptNotSet": "Prompt Not Set",
"promptNotSetDesc": "Could not find prompt for this image.",
"promptSet": "Prompt Set",
+ "prunedQueue": "Pruned Queue",
"resetInitialImage": "Reset Initial Image",
"seedNotSet": "Seed Not Set",
"seedNotSetDesc": "Could not find seed for this image.",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
index 26d7949e9c..6c5abf41a5 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
@@ -57,8 +57,8 @@ import { addInvocationCompleteEventListener as addInvocationCompleteListener } f
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
-import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addModelInstallEventListener } from './listeners/socketio/socketModelInstall';
+import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
index b241df27c8..a4cf8127cb 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts
@@ -1,63 +1,65 @@
-import { logger } from 'app/logging/logger';
+import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
} from 'services/events/actions';
-import { modelsApi } from 'services/api/endpoints/models';
-import type { components, paths } from 'services/api/schema';
-
import { startAppListening } from '../..';
-import { createEntityAdapter } from '@reduxjs/toolkit';
-
-const log = logger('socketio');
export const addModelInstallEventListener = () => {
startAppListening({
actionCreator: socketModelInstallDownloading,
effect: async (action, { dispatch }) => {
- const { bytes, local_path, source, timestamp, total_bytes } = action.payload.data;
- let message = `Model install started: ${bytes}/${total_bytes}/${source}`;
- // below doesnt work, still not sure how to update the importModels data
- // dispatch(
- // modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
- // importModelsAdapter.updateOne(draft, {
- // id: source,
- // changes: {
- // bytes,
- // total_bytes,
- // },\q
- // });
- // }
- // )
- // );
+ const { bytes, id } = action.payload.data;
- log.debug(action.payload, message);
+ dispatch(
+ modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
+ const models = JSON.parse(JSON.stringify(draft))
+
+ const modelIndex = models.findIndex((m) => m.id === id);
+
+ models[modelIndex].bytes = bytes;
+ models[modelIndex].status = 'downloading';
+ return models;
+ })
+ );
},
});
startAppListening({
actionCreator: socketModelInstallCompleted,
- effect: (action) => {
- const { key, source, timestamp } = action.payload.data;
+ effect: (action, { dispatch }) => {
+ const { id } = action.payload.data;
- let message = `Model install completed: ${source}`;
+ dispatch(
+ modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
+ const models = JSON.parse(JSON.stringify(draft))
- // dispatch something that marks the model installation as completed
+ const modelIndex = models.findIndex((m) => m.id === id);
- log.debug(action.payload, message);
+ models[modelIndex].status = 'completed';
+ return models;
+ })
+ );
},
});
startAppListening({
actionCreator: socketModelInstallError,
- effect: (action) => {
- const { error, error_type, source } = action.payload.data;
+ effect: (action, { dispatch }) => {
+ const { id } = action.payload.data;
- // dispatch something that marks the model installation as errored
+ dispatch(
+ modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
+ const models = JSON.parse(JSON.stringify(draft))
- log.debug(action.payload, error);
+ const modelIndex = models.findIndex((m) => m.id === id);
+
+ models[modelIndex].status = 'error';
+ return models;
+ })
+ );
},
});
};
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx
index 14ed0c6848..e2cdf24c2f 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx
@@ -1,11 +1,12 @@
-import { Button, Box, Flex, FormControl, FormLabel, Heading, Input, Text, Divider } from '@invoke-ai/ui-library';
-import { t } from 'i18next';
-import { CSSProperties } from 'react';
-import { useImportMainModelsMutation } from '../../../services/api/endpoints/models';
+import { Box, Button, Divider,Flex, FormControl, FormLabel, Heading, Input } from '@invoke-ai/ui-library';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
+import { t } from 'i18next';
+import type { CSSProperties } from 'react';
+import { useImportMainModelsMutation } from 'services/api/endpoints/models';
+
import { ImportQueue } from './ImportQueue';
const formStyles: CSSProperties = {
@@ -27,8 +28,6 @@ export const ImportModels = () => {
},
});
- console.log('addModelForm', addModelForm.values.location)
-
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
importMainModel({ source: values.location, config: undefined })
.unwrap()
@@ -77,7 +76,6 @@ export const ImportModels = () => {
- {t('modelManager.importQueue')}
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx
index c2dbece4b3..53c037e74f 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx
@@ -1,85 +1,72 @@
-import {
- Button,
- Box,
- Flex,
- FormControl,
- FormLabel,
- Heading,
- IconButton,
- Input,
- InputGroup,
- InputRightElement,
- Progress,
- Text,
-} from '@invoke-ai/ui-library';
-import { t } from 'i18next';
-import { useMemo } from 'react';
-import { useGetModelImportsQuery } from '../../../services/api/endpoints/models';
+import { Box, Button,Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
-import { PiXBold } from 'react-icons/pi';
+import { t } from 'i18next';
+import { useCallback, useMemo } from 'react';
+import { RiSparklingFill } from 'react-icons/ri';
+import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
+
+import { ImportQueueModel } from './ImportQueueModel';
export const ImportQueue = () => {
const dispatch = useAppDispatch();
- // start with this data then pull from sockets (idk how to do that yet, also might not even use this and just use socket)
const { data } = useGetModelImportsQuery();
- const progressValues = useMemo(() => {
- if (!data) {
- return [];
- }
- const values = [];
- for (let i = 0; i < data.length; i++) {
- let value;
- if (data[i] && data[i]?.bytes && data[i]?.total_bytes) {
- value = (data[i]?.bytes / data[i]?.total_bytes) * 100;
- }
- values.push(value || undefined);
- }
- return values;
+ const [pruneModelImports] = usePruneModelImportsMutation();
+
+ const pruneQueue = useCallback(() => {
+ pruneModelImports()
+ .unwrap()
+ .then((_) => {
+ dispatch(
+ addToast(
+ makeToast({
+ title: t('toast.prunedQueue'),
+ status: 'success',
+ })
+ )
+ );
+ })
+ .catch((error) => {
+ if (error) {
+ dispatch(
+ addToast(
+ makeToast({
+ title: `${error.data.detail} `,
+ status: 'error',
+ })
+ )
+ );
+ }
+ });
+ }, [pruneModelImports, dispatch]);
+
+ const pruneAvailable = useMemo(() => {
+ return data?.some(
+ (model) => model.status === 'cancelled' || model.status === 'error' || model.status === 'completed'
+ );
}, [data]);
return (
-
-
- {data?.map((model, i) => (
-
-
- {model.source.repo_id}
-
-
- {model.status}
- {model.status === 'completed' ? (
- }
- // onClick={handleRemove}
- />
- ) : (
- }
- // onClick={handleCancel}
- colorScheme="error"
- />
- )}
-
- ))}
+ <>
+
+ {t('modelManager.importQueue')}
+ }
+ >
+ {t('modelManager.prune')}
+
-
+
+
+ {data?.map((model) => )}
+
+
+ >
);
};
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueueModel.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueueModel.tsx
new file mode 100644
index 0000000000..8704e847c6
--- /dev/null
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueueModel.tsx
@@ -0,0 +1,82 @@
+import {
+ Flex,
+ IconButton,
+ Progress,
+ Text,
+} from '@invoke-ai/ui-library';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
+import { t } from 'i18next';
+import { useCallback, useMemo } from 'react';
+import { PiXBold } from 'react-icons/pi';
+import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
+import type { ImportModelConfig } from 'services/api/types';
+
+type ModelListItemProps = {
+ model: ImportModelConfig;
+};
+
+export const ImportQueueModel = (props: ModelListItemProps) => {
+ const { model } = props;
+ const dispatch = useAppDispatch();
+
+ const [deleteImportModel] = useDeleteModelImportMutation();
+
+ const handleDeleteModelImport = useCallback(() => {
+ deleteImportModel({ key: model.id })
+ .unwrap()
+ .then((_) => {
+ dispatch(
+ addToast(
+ makeToast({
+ title: t('toast.modelImportCanceled'),
+ status: 'success',
+ })
+ )
+ );
+ })
+ .catch((error) => {
+ if (error) {
+ dispatch(
+ addToast(
+ makeToast({
+ title: `${error.data.detail} `,
+ status: 'error',
+ })
+ )
+ );
+ }
+ });
+ }, [deleteImportModel, model, dispatch]);
+
+ const progressValue = useMemo(() => {
+ return (model.bytes / model.total_bytes) * 100;
+ }, [model.bytes, model.total_bytes]);
+
+ return (
+
+
+ {model.source.repo_id}
+
+
+ {model.status}
+ {(model.status === 'downloading' || model.status === 'waiting') && (
+ }
+ onClick={handleDeleteModelImport}
+ />
+ )}
+
+ );
+};
diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts
index 48ee2a7f08..3e1abc8560 100644
--- a/invokeai/frontend/web/src/services/api/endpoints/models.ts
+++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts
@@ -26,11 +26,9 @@ type UpdateModelArg = {
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
-
-type GetModelResponse =
- paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse =
paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json'];
+type GetModelResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable;
@@ -68,13 +66,19 @@ type ImportMainModelResponse =
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
+type DeleteImportModelsResponse =
+ paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
+
+type PruneModelImportsResponse =
+ paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
+
type AddMainModelArg = {
body: MainModelConfig;
};
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
-type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content']['application/json'];
+type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content']
export type SearchFolderResponse =
paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json'];
@@ -308,7 +312,25 @@ export const modelsApi = api.injectEndpoints({
url: buildModelsUrl(`import`),
};
},
- providesTags: ['ModelImports']
+ providesTags: ['ModelImports'],
+ }),
+ deleteModelImport: build.mutation({
+ query: ({ key }) => {
+ return {
+ url: buildModelsUrl(`import/${key}`),
+ method: 'DELETE',
+ };
+ },
+ invalidatesTags: ['ModelImports'],
+ }),
+ pruneModelImports: build.mutation({
+ query: () => {
+ return {
+ url: buildModelsUrl('import'),
+ method: 'PATCH',
+ };
+ },
+ invalidatesTags: ['ModelImports'],
}),
getCheckpointConfigs: build.query({
query: () => {
@@ -339,6 +361,8 @@ export const {
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
useGetModelImportsQuery,
+ useGetModelMetadataQuery,
+ useDeleteModelImportMutation,
+ usePruneModelImportsMutation,
useGetModelQuery,
- useGetModelMetadataQuery
} = modelsApi;
diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts
index c306b3d2f6..31f9e4d11d 100644
--- a/invokeai/frontend/web/src/services/events/actions.ts
+++ b/invokeai/frontend/web/src/services/events/actions.ts
@@ -9,11 +9,11 @@ import type {
InvocationErrorEvent,
InvocationRetrievalErrorEvent,
InvocationStartedEvent,
+ ModelInstallCompletedEvent,
+ ModelInstallDownloadingEvent,
+ ModelInstallErrorEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
- ModelInstallDownloadingEvent,
- ModelInstallCompletedEvent,
- ModelInstallErrorEvent,
QueueItemStatusChangedEvent,
SessionRetrievalErrorEvent,
} from 'services/events/types';
diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts
index 2cf73a8e69..e0c50f4def 100644
--- a/invokeai/frontend/web/src/services/events/types.ts
+++ b/invokeai/frontend/web/src/services/events/types.ts
@@ -51,12 +51,14 @@ export type ModelInstallDownloadingEvent = {
source: string;
timestamp: number;
total_bytes: string;
+ id: number;
};
export type ModelInstallCompletedEvent = {
key: number;
source: string;
timestamp: number;
+ id: number;
};
export type ModelInstallErrorEvent = {
@@ -64,6 +66,7 @@ export type ModelInstallErrorEvent = {
error_type: string;
source: string;
timestamp: number;
+ id: number;
};
/**
diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
index 314f9a985f..38f917e394 100644
--- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
+++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
@@ -16,11 +16,11 @@ import {
socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted,
+ socketModelInstallCompleted,
+ socketModelInstallDownloading,
+ socketModelInstallError,
socketModelLoadCompleted,
socketModelLoadStarted,
- socketModelInstallDownloading,
- socketModelInstallCompleted,
- socketModelInstallError,
socketQueueItemStatusChanged,
socketSessionRetrievalError,
} from 'services/events/actions';