delete model imports and prune all finished, update state with socket messages

This commit is contained in:
Jennifer Player 2024-02-21 17:17:34 -05:00 committed by psychedelicious
parent 18904f79ef
commit ea364bdf82
12 changed files with 232 additions and 125 deletions

View File

@ -357,6 +357,7 @@ class EventServiceBase:
bytes: int, bytes: int,
total_bytes: int, total_bytes: int,
parts: List[Dict[str, Union[str, int]]], parts: List[Dict[str, Union[str, int]]],
id: int
) -> None: ) -> None:
""" """
Emit at intervals while the install job is in progress (remote models only). Emit at intervals while the install job is in progress (remote models only).
@ -376,6 +377,7 @@ class EventServiceBase:
"bytes": bytes, "bytes": bytes,
"total_bytes": total_bytes, "total_bytes": total_bytes,
"parts": parts, "parts": parts,
"id": id
}, },
) )
@ -390,7 +392,7 @@ class EventServiceBase:
payload={"source": source}, payload={"source": source},
) )
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None: def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
""" """
Emit when an install job is completed successfully. Emit when an install job is completed successfully.
@ -404,6 +406,7 @@ class EventServiceBase:
"source": source, "source": source,
"total_bytes": total_bytes, "total_bytes": total_bytes,
"key": key, "key": key,
"id": id
}, },
) )
@ -423,6 +426,7 @@ class EventServiceBase:
source: str, source: str,
error_type: str, error_type: str,
error: str, error: str,
id: int
) -> None: ) -> None:
""" """
Emit when an install job encounters an exception. Emit when an install job encounters an exception.
@ -437,6 +441,7 @@ class EventServiceBase:
"source": source, "source": source,
"error_type": error_type, "error_type": error_type,
"error": error, "error": error,
"id": id
}, },
) )

View File

@ -822,6 +822,7 @@ class ModelInstallService(ModelInstallServiceBase):
parts=parts, parts=parts,
bytes=job.bytes, bytes=job.bytes,
total_bytes=job.total_bytes, total_bytes=job.total_bytes,
id=job.id
) )
def _signal_job_completed(self, job: ModelInstallJob) -> None: def _signal_job_completed(self, job: ModelInstallJob) -> None:
@ -834,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase):
assert job.local_path is not None assert job.local_path is not None
assert job.config_out is not None assert job.config_out is not None
key = job.config_out.key key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key) self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None: def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}") self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
@ -843,7 +844,7 @@ class ModelInstallService(ModelInstallServiceBase):
error = job.error error = job.error
assert error_type is not None assert error_type is not None
assert error is not None assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error) self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None: def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled") self._logger.info(f"{job.source}: model installation was cancelled")

View File

@ -797,6 +797,8 @@
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"pickModelType": "Pick Model Type", "pickModelType": "Pick Model Type",
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)", "predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
"prune": "Prune",
"pruneTooltip": "Prune finished imports from queue",
"quickAdd": "Quick Add", "quickAdd": "Quick Add",
"removeFromQueue": "Remove From Queue", "removeFromQueue": "Remove From Queue",
"repo_id": "Repo ID", "repo_id": "Repo ID",
@ -1354,6 +1356,8 @@
"modelAdded": "Model Added: {{modelName}}", "modelAdded": "Model Added: {{modelName}}",
"modelAddedSimple": "Model Added", "modelAddedSimple": "Model Added",
"modelAddFailed": "Model Add Failed", "modelAddFailed": "Model Add Failed",
"modelImportCanceled": "Model Import Canceled",
"modelImportRemoved": "Model Import Removed",
"nodesBrokenConnections": "Cannot load. Some connections are broken.", "nodesBrokenConnections": "Cannot load. Some connections are broken.",
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.", "nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
"nodesLoaded": "Nodes Loaded", "nodesLoaded": "Nodes Loaded",
@ -1387,6 +1391,7 @@
"promptNotSet": "Prompt Not Set", "promptNotSet": "Prompt Not Set",
"promptNotSetDesc": "Could not find prompt for this image.", "promptNotSetDesc": "Could not find prompt for this image.",
"promptSet": "Prompt Set", "promptSet": "Prompt Set",
"prunedQueue": "Pruned Queue",
"resetInitialImage": "Reset Initial Image", "resetInitialImage": "Reset Initial Image",
"seedNotSet": "Seed Not Set", "seedNotSet": "Seed Not Set",
"seedNotSetDesc": "Could not find seed for this image.", "seedNotSetDesc": "Could not find seed for this image.",

View File

@ -57,8 +57,8 @@ import { addInvocationCompleteEventListener as addInvocationCompleteListener } f
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError'; import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addModelInstallEventListener } from './listeners/socketio/socketModelInstall'; import { addModelInstallEventListener } from './listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged'; import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';

View File

@ -1,63 +1,65 @@
import { logger } from 'app/logging/logger'; import { modelsApi } from 'services/api/endpoints/models';
import { import {
socketModelInstallCompleted, socketModelInstallCompleted,
socketModelInstallDownloading, socketModelInstallDownloading,
socketModelInstallError, socketModelInstallError,
} from 'services/events/actions'; } from 'services/events/actions';
import { modelsApi } from 'services/api/endpoints/models';
import type { components, paths } from 'services/api/schema';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { createEntityAdapter } from '@reduxjs/toolkit';
const log = logger('socketio');
export const addModelInstallEventListener = () => { export const addModelInstallEventListener = () => {
startAppListening({ startAppListening({
actionCreator: socketModelInstallDownloading, actionCreator: socketModelInstallDownloading,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch }) => {
const { bytes, local_path, source, timestamp, total_bytes } = action.payload.data; const { bytes, id } = action.payload.data;
let message = `Model install started: ${bytes}/${total_bytes}/${source}`;
// below doesnt work, still not sure how to update the importModels data
// dispatch(
// modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
// importModelsAdapter.updateOne(draft, {
// id: source,
// changes: {
// bytes,
// total_bytes,
// },\q
// });
// }
// )
// );
log.debug(action.payload, message); dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const models = JSON.parse(JSON.stringify(draft))
const modelIndex = models.findIndex((m) => m.id === id);
models[modelIndex].bytes = bytes;
models[modelIndex].status = 'downloading';
return models;
})
);
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallCompleted, actionCreator: socketModelInstallCompleted,
effect: (action) => { effect: (action, { dispatch }) => {
const { key, source, timestamp } = action.payload.data; const { id } = action.payload.data;
let message = `Model install completed: ${source}`; dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const models = JSON.parse(JSON.stringify(draft))
// dispatch something that marks the model installation as completed const modelIndex = models.findIndex((m) => m.id === id);
log.debug(action.payload, message); models[modelIndex].status = 'completed';
return models;
})
);
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallError, actionCreator: socketModelInstallError,
effect: (action) => { effect: (action, { dispatch }) => {
const { error, error_type, source } = action.payload.data; const { id } = action.payload.data;
// dispatch something that marks the model installation as errored dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const models = JSON.parse(JSON.stringify(draft))
log.debug(action.payload, error); const modelIndex = models.findIndex((m) => m.id === id);
models[modelIndex].status = 'error';
return models;
})
);
}, },
}); });
}; };

View File

@ -1,11 +1,12 @@
import { Button, Box, Flex, FormControl, FormLabel, Heading, Input, Text, Divider } from '@invoke-ai/ui-library'; import { Box, Button, Divider,Flex, FormControl, FormLabel, Heading, Input } from '@invoke-ai/ui-library';
import { t } from 'i18next';
import { CSSProperties } from 'react';
import { useImportMainModelsMutation } from '../../../services/api/endpoints/models';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import type { CSSProperties } from 'react';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
import { ImportQueue } from './ImportQueue'; import { ImportQueue } from './ImportQueue';
const formStyles: CSSProperties = { const formStyles: CSSProperties = {
@ -27,8 +28,6 @@ export const ImportModels = () => {
}, },
}); });
console.log('addModelForm', addModelForm.values.location)
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => { const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
importMainModel({ source: values.location, config: undefined }) importMainModel({ source: values.location, config: undefined })
.unwrap() .unwrap()
@ -77,7 +76,6 @@ export const ImportModels = () => {
</Flex> </Flex>
</form> </form>
<Divider mt="5" mb="3" /> <Divider mt="5" mb="3" />
<Text>{t('modelManager.importQueue')}</Text>
<ImportQueue /> <ImportQueue />
</Box> </Box>
</Box> </Box>

View File

@ -1,85 +1,72 @@
import { import { Box, Button,Flex, Text } from '@invoke-ai/ui-library';
Button,
Box,
Flex,
FormControl,
FormLabel,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
Progress,
Text,
} from '@invoke-ai/ui-library';
import { t } from 'i18next';
import { useMemo } from 'react';
import { useGetModelImportsQuery } from '../../../services/api/endpoints/models';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { PiXBold } from 'react-icons/pi'; import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { RiSparklingFill } from 'react-icons/ri';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { ImportQueueModel } from './ImportQueueModel';
export const ImportQueue = () => { export const ImportQueue = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
// start with this data then pull from sockets (idk how to do that yet, also might not even use this and just use socket)
const { data } = useGetModelImportsQuery(); const { data } = useGetModelImportsQuery();
const progressValues = useMemo(() => { const [pruneModelImports] = usePruneModelImportsMutation();
if (!data) {
return []; const pruneQueue = useCallback(() => {
} pruneModelImports()
const values = []; .unwrap()
for (let i = 0; i < data.length; i++) { .then((_) => {
let value; dispatch(
if (data[i] && data[i]?.bytes && data[i]?.total_bytes) { addToast(
value = (data[i]?.bytes / data[i]?.total_bytes) * 100; makeToast({
} title: t('toast.prunedQueue'),
values.push(value || undefined); status: 'success',
} })
return values; )
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [pruneModelImports, dispatch]);
const pruneAvailable = useMemo(() => {
return data?.some(
(model) => model.status === 'cancelled' || model.status === 'error' || model.status === 'completed'
);
}, [data]); }, [data]);
return ( return (
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full"> <>
<Flex direction="column" gap="2"> <Flex justifyContent="space-between">
{data?.map((model, i) => ( <Text>{t('modelManager.importQueue')}</Text>
<Flex key={i} gap="3" w="full" alignItems="center" textAlign="center"> <Button
<Text w="20%" whiteSpace="nowrap" overflow="hidden" text-overflow="ellipsis"> isDisabled={!pruneAvailable}
{model.source.repo_id} onClick={pruneQueue}
</Text> tooltip={t('modelManager.pruneTooltip')}
<Progress rightIcon={<RiSparklingFill />}
value={progressValues[i]} >
isIndeterminate={progressValues[i] === undefined} {t('modelManager.prune')}
aria-label={t('accessibility.invokeProgressBar')} </Button>
h={2}
w="50%"
/>
<Text w="20%">{model.status}</Text>
{model.status === 'completed' ? (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.removeFromQueue')}
aria-label={t('modelManager.removeFromQueue')}
icon={<PiXBold />}
// onClick={handleRemove}
/>
) : (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
// onClick={handleCancel}
colorScheme="error"
/>
)}
</Flex>
))}
</Flex> </Flex>
</Box> <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Flex direction="column" gap="2">
{data?.map((model) => <ImportQueueModel key={model.id} model={model} />)}
</Flex>
</Box>
</>
); );
}; };

View File

@ -0,0 +1,82 @@
import {
Flex,
IconButton,
Progress,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ImportModelConfig } from 'services/api/types';
type ModelListItemProps = {
model: ImportModelConfig;
};
export const ImportQueueModel = (props: ModelListItemProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const handleDeleteModelImport = useCallback(() => {
deleteImportModel({ key: model.id })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelImportCanceled'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [deleteImportModel, model, dispatch]);
const progressValue = useMemo(() => {
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
return (
<Flex gap="3" w="full" alignItems="center" textAlign="center">
<Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{model.source.repo_id}
</Text>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
w="50%"
/>
<Text w="20%">{model.status}</Text>
{(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Flex>
);
};

View File

@ -26,11 +26,9 @@ type UpdateModelArg = {
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type GetModelResponse =
paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse = type GetModelMetadataResponse =
paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>; type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
@ -68,13 +66,19 @@ type ImportMainModelResponse =
type ListImportModelsResponse = type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
type DeleteImportModelsResponse =
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type AddMainModelArg = { type AddMainModelArg = {
body: MainModelConfig; body: MainModelConfig;
}; };
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json']; type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content']['application/json']; type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content']
export type SearchFolderResponse = export type SearchFolderResponse =
paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json'];
@ -308,7 +312,25 @@ export const modelsApi = api.injectEndpoints({
url: buildModelsUrl(`import`), url: buildModelsUrl(`import`),
}; };
}, },
providesTags: ['ModelImports'] providesTags: ['ModelImports'],
}),
deleteModelImport: build.mutation<DeleteImportModelsResponse, DeleteMainModelArg>({
query: ({ key }) => {
return {
url: buildModelsUrl(`import/${key}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelImports'],
}),
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
query: () => {
return {
url: buildModelsUrl('import'),
method: 'PATCH',
};
},
invalidatesTags: ['ModelImports'],
}), }),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({ getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => { query: () => {
@ -339,6 +361,8 @@ export const {
useGetModelsInFolderQuery, useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery, useGetCheckpointConfigsQuery,
useGetModelImportsQuery, useGetModelImportsQuery,
useGetModelMetadataQuery,
useDeleteModelImportMutation,
usePruneModelImportsMutation,
useGetModelQuery, useGetModelQuery,
useGetModelMetadataQuery
} = modelsApi; } = modelsApi;

View File

@ -9,11 +9,11 @@ import type {
InvocationErrorEvent, InvocationErrorEvent,
InvocationRetrievalErrorEvent, InvocationRetrievalErrorEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelInstallCompletedEvent,
ModelInstallDownloadingEvent,
ModelInstallErrorEvent,
ModelLoadCompletedEvent, ModelLoadCompletedEvent,
ModelLoadStartedEvent, ModelLoadStartedEvent,
ModelInstallDownloadingEvent,
ModelInstallCompletedEvent,
ModelInstallErrorEvent,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,
SessionRetrievalErrorEvent, SessionRetrievalErrorEvent,
} from 'services/events/types'; } from 'services/events/types';

View File

@ -51,12 +51,14 @@ export type ModelInstallDownloadingEvent = {
source: string; source: string;
timestamp: number; timestamp: number;
total_bytes: string; total_bytes: string;
id: number;
}; };
export type ModelInstallCompletedEvent = { export type ModelInstallCompletedEvent = {
key: number; key: number;
source: string; source: string;
timestamp: number; timestamp: number;
id: number;
}; };
export type ModelInstallErrorEvent = { export type ModelInstallErrorEvent = {
@ -64,6 +66,7 @@ export type ModelInstallErrorEvent = {
error_type: string; error_type: string;
source: string; source: string;
timestamp: number; timestamp: number;
id: number;
}; };
/** /**

View File

@ -16,11 +16,11 @@ import {
socketInvocationError, socketInvocationError,
socketInvocationRetrievalError, socketInvocationRetrievalError,
socketInvocationStarted, socketInvocationStarted,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
socketModelLoadCompleted, socketModelLoadCompleted,
socketModelLoadStarted, socketModelLoadStarted,
socketModelInstallDownloading,
socketModelInstallCompleted,
socketModelInstallError,
socketQueueItemStatusChanged, socketQueueItemStatusChanged,
socketSessionRetrievalError, socketSessionRetrievalError,
} from 'services/events/actions'; } from 'services/events/actions';