delete model imports and prune all finished, update state with socket messages

This commit is contained in:
Jennifer Player 2024-02-21 17:17:34 -05:00 committed by psychedelicious
parent 66692f02aa
commit 23c412e011
12 changed files with 232 additions and 125 deletions

View File

@ -357,6 +357,7 @@ class EventServiceBase:
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
@ -376,6 +377,7 @@ class EventServiceBase:
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id
},
)
@ -390,7 +392,7 @@ class EventServiceBase:
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
@ -404,6 +406,7 @@ class EventServiceBase:
"source": source,
"total_bytes": total_bytes,
"key": key,
"id": id
},
)
@ -423,6 +426,7 @@ class EventServiceBase:
source: str,
error_type: str,
error: str,
id: int
) -> None:
"""
Emit when an install job encounters an exception.
@ -437,6 +441,7 @@ class EventServiceBase:
"source": source,
"error_type": error_type,
"error": error,
"id": id
},
)

View File

@ -822,6 +822,7 @@ class ModelInstallService(ModelInstallServiceBase):
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id
)
def _signal_job_completed(self, job: ModelInstallJob) -> None:
@ -834,7 +835,7 @@ class ModelInstallService(ModelInstallServiceBase):
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key)
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
@ -843,7 +844,7 @@ class ModelInstallService(ModelInstallServiceBase):
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled")

View File

@ -797,6 +797,8 @@
"pathToCustomConfig": "Path To Custom Config",
"pickModelType": "Pick Model Type",
"predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)",
"prune": "Prune",
"pruneTooltip": "Prune finished imports from queue",
"quickAdd": "Quick Add",
"removeFromQueue": "Remove From Queue",
"repo_id": "Repo ID",
@ -1354,6 +1356,8 @@
"modelAdded": "Model Added: {{modelName}}",
"modelAddedSimple": "Model Added",
"modelAddFailed": "Model Add Failed",
"modelImportCanceled": "Model Import Canceled",
"modelImportRemoved": "Model Import Removed",
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
"nodesLoaded": "Nodes Loaded",
@ -1387,6 +1391,7 @@
"promptNotSet": "Prompt Not Set",
"promptNotSetDesc": "Could not find prompt for this image.",
"promptSet": "Prompt Set",
"prunedQueue": "Pruned Queue",
"resetInitialImage": "Reset Initial Image",
"seedNotSet": "Seed Not Set",
"seedNotSetDesc": "Could not find seed for this image.",

View File

@ -57,8 +57,8 @@ import { addInvocationCompleteEventListener as addInvocationCompleteListener } f
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addModelInstallEventListener } from './listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';

View File

@ -1,63 +1,65 @@
import { logger } from 'app/logging/logger';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
} from 'services/events/actions';
import { modelsApi } from 'services/api/endpoints/models';
import type { components, paths } from 'services/api/schema';
import { startAppListening } from '../..';
import { createEntityAdapter } from '@reduxjs/toolkit';
const log = logger('socketio');
export const addModelInstallEventListener = () => {
startAppListening({
actionCreator: socketModelInstallDownloading,
effect: async (action, { dispatch }) => {
const { bytes, local_path, source, timestamp, total_bytes } = action.payload.data;
let message = `Model install started: ${bytes}/${total_bytes}/${source}`;
// below doesnt work, still not sure how to update the importModels data
// dispatch(
// modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
// importModelsAdapter.updateOne(draft, {
// id: source,
// changes: {
// bytes,
// total_bytes,
// },\q
// });
// }
// )
// );
const { bytes, id } = action.payload.data;
log.debug(action.payload, message);
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const models = JSON.parse(JSON.stringify(draft))
const modelIndex = models.findIndex((m) => m.id === id);
models[modelIndex].bytes = bytes;
models[modelIndex].status = 'downloading';
return models;
})
);
},
});
startAppListening({
actionCreator: socketModelInstallCompleted,
effect: (action) => {
const { key, source, timestamp } = action.payload.data;
effect: (action, { dispatch }) => {
const { id } = action.payload.data;
let message = `Model install completed: ${source}`;
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const models = JSON.parse(JSON.stringify(draft))
// dispatch something that marks the model installation as completed
const modelIndex = models.findIndex((m) => m.id === id);
log.debug(action.payload, message);
models[modelIndex].status = 'completed';
return models;
})
);
},
});
startAppListening({
actionCreator: socketModelInstallError,
effect: (action) => {
const { error, error_type, source } = action.payload.data;
effect: (action, { dispatch }) => {
const { id } = action.payload.data;
// dispatch something that marks the model installation as errored
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
const models = JSON.parse(JSON.stringify(draft))
log.debug(action.payload, error);
const modelIndex = models.findIndex((m) => m.id === id);
models[modelIndex].status = 'error';
return models;
})
);
},
});
};

View File

@ -1,11 +1,12 @@
import { Button, Box, Flex, FormControl, FormLabel, Heading, Input, Text, Divider } from '@invoke-ai/ui-library';
import { t } from 'i18next';
import { CSSProperties } from 'react';
import { useImportMainModelsMutation } from '../../../services/api/endpoints/models';
import { Box, Button, Divider,Flex, FormControl, FormLabel, Heading, Input } from '@invoke-ai/ui-library';
import { useForm } from '@mantine/form';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import type { CSSProperties } from 'react';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
import { ImportQueue } from './ImportQueue';
const formStyles: CSSProperties = {
@ -27,8 +28,6 @@ export const ImportModels = () => {
},
});
console.log('addModelForm', addModelForm.values.location)
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
importMainModel({ source: values.location, config: undefined })
.unwrap()
@ -77,7 +76,6 @@ export const ImportModels = () => {
</Flex>
</form>
<Divider mt="5" mb="3" />
<Text>{t('modelManager.importQueue')}</Text>
<ImportQueue />
</Box>
</Box>

View File

@ -1,85 +1,72 @@
import {
Button,
Box,
Flex,
FormControl,
FormLabel,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
Progress,
Text,
} from '@invoke-ai/ui-library';
import { t } from 'i18next';
import { useMemo } from 'react';
import { useGetModelImportsQuery } from '../../../services/api/endpoints/models';
import { Box, Button,Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { PiXBold } from 'react-icons/pi';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { RiSparklingFill } from 'react-icons/ri';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { ImportQueueModel } from './ImportQueueModel';
export const ImportQueue = () => {
const dispatch = useAppDispatch();
// start with this data then pull from sockets (idk how to do that yet, also might not even use this and just use socket)
const { data } = useGetModelImportsQuery();
const progressValues = useMemo(() => {
if (!data) {
return [];
}
const values = [];
for (let i = 0; i < data.length; i++) {
let value;
if (data[i] && data[i]?.bytes && data[i]?.total_bytes) {
value = (data[i]?.bytes / data[i]?.total_bytes) * 100;
}
values.push(value || undefined);
}
return values;
const [pruneModelImports] = usePruneModelImportsMutation();
const pruneQueue = useCallback(() => {
pruneModelImports()
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.prunedQueue'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [pruneModelImports, dispatch]);
const pruneAvailable = useMemo(() => {
return data?.some(
(model) => model.status === 'cancelled' || model.status === 'error' || model.status === 'completed'
);
}, [data]);
return (
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Flex direction="column" gap="2">
{data?.map((model, i) => (
<Flex key={i} gap="3" w="full" alignItems="center" textAlign="center">
<Text w="20%" whiteSpace="nowrap" overflow="hidden" text-overflow="ellipsis">
{model.source.repo_id}
</Text>
<Progress
value={progressValues[i]}
isIndeterminate={progressValues[i] === undefined}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
w="50%"
/>
<Text w="20%">{model.status}</Text>
{model.status === 'completed' ? (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.removeFromQueue')}
aria-label={t('modelManager.removeFromQueue')}
icon={<PiXBold />}
// onClick={handleRemove}
/>
) : (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
// onClick={handleCancel}
colorScheme="error"
/>
)}
</Flex>
))}
<>
<Flex justifyContent="space-between">
<Text>{t('modelManager.importQueue')}</Text>
<Button
isDisabled={!pruneAvailable}
onClick={pruneQueue}
tooltip={t('modelManager.pruneTooltip')}
rightIcon={<RiSparklingFill />}
>
{t('modelManager.prune')}
</Button>
</Flex>
</Box>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Flex direction="column" gap="2">
{data?.map((model) => <ImportQueueModel key={model.id} model={model} />)}
</Flex>
</Box>
</>
);
};

View File

@ -0,0 +1,82 @@
import {
Flex,
IconButton,
Progress,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import type { ImportModelConfig } from 'services/api/types';
type ModelListItemProps = {
model: ImportModelConfig;
};
export const ImportQueueModel = (props: ModelListItemProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const handleDeleteModelImport = useCallback(() => {
deleteImportModel({ key: model.id })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelImportCanceled'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [deleteImportModel, model, dispatch]);
const progressValue = useMemo(() => {
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
return (
<Flex gap="3" w="full" alignItems="center" textAlign="center">
<Text w="20%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{model.source.repo_id}
</Text>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
w="50%"
/>
<Text w="20%">{model.status}</Text>
{(model.status === 'downloading' || model.status === 'waiting') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Flex>
);
};

View File

@ -26,11 +26,9 @@ type UpdateModelArg = {
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type GetModelResponse =
paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse =
paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
@ -68,13 +66,19 @@ type ImportMainModelResponse =
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
type DeleteImportModelsResponse =
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type AddMainModelArg = {
body: MainModelConfig;
};
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content']['application/json'];
type SyncModelsResponse = paths['/api/v2/models/sync']['patch']['responses']['204']['content']
export type SearchFolderResponse =
paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json'];
@ -308,7 +312,25 @@ export const modelsApi = api.injectEndpoints({
url: buildModelsUrl(`import`),
};
},
providesTags: ['ModelImports']
providesTags: ['ModelImports'],
}),
deleteModelImport: build.mutation<DeleteImportModelsResponse, DeleteMainModelArg>({
query: ({ key }) => {
return {
url: buildModelsUrl(`import/${key}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelImports'],
}),
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
query: () => {
return {
url: buildModelsUrl('import'),
method: 'PATCH',
};
},
invalidatesTags: ['ModelImports'],
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
@ -339,6 +361,8 @@ export const {
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
useGetModelImportsQuery,
useGetModelMetadataQuery,
useDeleteModelImportMutation,
usePruneModelImportsMutation,
useGetModelQuery,
useGetModelMetadataQuery
} = modelsApi;

View File

@ -9,11 +9,11 @@ import type {
InvocationErrorEvent,
InvocationRetrievalErrorEvent,
InvocationStartedEvent,
ModelInstallCompletedEvent,
ModelInstallDownloadingEvent,
ModelInstallErrorEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
ModelInstallDownloadingEvent,
ModelInstallCompletedEvent,
ModelInstallErrorEvent,
QueueItemStatusChangedEvent,
SessionRetrievalErrorEvent,
} from 'services/events/types';

View File

@ -51,12 +51,14 @@ export type ModelInstallDownloadingEvent = {
source: string;
timestamp: number;
total_bytes: string;
id: number;
};
export type ModelInstallCompletedEvent = {
key: number;
source: string;
timestamp: number;
id: number;
};
export type ModelInstallErrorEvent = {
@ -64,6 +66,7 @@ export type ModelInstallErrorEvent = {
error_type: string;
source: string;
timestamp: number;
id: number;
};
/**

View File

@ -16,11 +16,11 @@ import {
socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
socketModelLoadCompleted,
socketModelLoadStarted,
socketModelInstallDownloading,
socketModelInstallCompleted,
socketModelInstallError,
socketQueueItemStatusChanged,
socketSessionRetrievalError,
} from 'services/events/actions';