feat(ui): handle new model_install_download_started event

When a model install is initiated from outside the client, we now trigger the model manager tab's model install list to update.

- Handle new `model_install_download_started` event
- Handle `model_install_download_complete` event (this event is not new but was never handled)
- Update optimistic updates/cache invalidation logic to efficiently update the model install list
This commit is contained in:
psychedelicious 2024-06-17 10:07:10 +10:00
parent 56771de856
commit f002bca2fa

View File

@ -5,6 +5,8 @@ import {
socketModelInstallCancelled, socketModelInstallCancelled,
socketModelInstallComplete, socketModelInstallComplete,
socketModelInstallDownloadProgress, socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallDownloadStarted,
socketModelInstallError, socketModelInstallError,
socketModelInstallStarted, socketModelInstallStarted,
} from 'services/events/actions'; } from 'services/events/actions';
@ -14,9 +16,12 @@ import {
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
* downloaded and is being "physically" installed. * downloaded and is being "physically" installed.
* *
* Note: the download events are only fired for remote model installs, not local.
*
* Here's the expected flow: * Here's the expected flow:
* - Model manager does some prep * - API receives install request, model manager preps the install
* - `model_install_download_progress` fired when the download starts and continually until the download is complete * - `model_install_download_started` fired when the download starts
* - `model_install_download_progress` fired continually until the download is complete
* - `model_install_download_complete` fired when the download is complete * - `model_install_download_complete` fired when the download is complete
* - `model_install_started` fired when the "physical" installation starts * - `model_install_started` fired when the "physical" installation starts
* - `model_install_complete` fired when the installation is complete * - `model_install_complete` fired when the installation is complete
@ -24,19 +29,62 @@ import {
* - `model_install_error` fired if the installation has an error * - `model_install_error` fired if the installation has an error
*/ */
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
export const addModelInstallEventListener = (startAppListening: AppStartListening) => { export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: socketModelInstallStarted, actionCreator: socketModelInstallDownloadStarted,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloading';
}
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallStarted,
effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'running';
}
return draft;
})
);
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallDownloadProgress, actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch, getState }) => {
const { bytes, total_bytes, id } = action.payload.data; const { bytes, total_bytes, id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -48,14 +96,20 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallComplete, actionCreator: socketModelInstallComplete,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data; const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -65,6 +119,8 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
}, },
@ -72,9 +128,13 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
startAppListening({ startAppListening({
actionCreator: socketModelInstallError, actionCreator: socketModelInstallError,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id, error, error_type } = action.payload.data; const { id, error, error_type } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -86,14 +146,19 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallCancelled, actionCreator: socketModelInstallCancelled,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data; const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch( dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
@ -103,6 +168,29 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadsComplete,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloads_done';
}
return draft;
})
);
}
}, },
}); });
}; };