feat(ui): handle new model_install_download_started event

When a model install is initiated from outside the client, we now trigger the model manager tab's model install list to update.

- Handle new `model_install_download_started` event
- Handle `model_install_download_complete` event (this event is not new but was never handled)
- Update optimistic updates/cache invalidation logic to efficiently update the model install list
This commit is contained in:
psychedelicious 2024-06-17 10:07:10 +10:00
parent 56771de856
commit f002bca2fa

View File

@ -5,6 +5,8 @@ import {
socketModelInstallCancelled, socketModelInstallCancelled,
socketModelInstallComplete, socketModelInstallComplete,
socketModelInstallDownloadProgress, socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallDownloadStarted,
socketModelInstallError, socketModelInstallError,
socketModelInstallStarted, socketModelInstallStarted,
} from 'services/events/actions'; } from 'services/events/actions';
@ -14,9 +16,12 @@ import {
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
* downloaded and is being "physically" installed. * downloaded and is being "physically" installed.
* *
* Note: the download events are only fired for remote model installs, not local.
*
* Here's the expected flow: * Here's the expected flow:
* - Model manager does some prep * - API receives install request, model manager preps the install
* - `model_install_download_progress` fired when the download starts and continually until the download is complete * - `model_install_download_started` fired when the download starts
* - `model_install_download_progress` fired continually until the download is complete
* - `model_install_download_complete` fired when the download is complete * - `model_install_download_complete` fired when the download is complete
* - `model_install_started` fired when the "physical" installation starts * - `model_install_started` fired when the "physical" installation starts
* - `model_install_complete` fired when the installation is complete * - `model_install_complete` fired when the installation is complete
@ -24,47 +29,98 @@ import {
* - `model_install_error` fired if the installation has an error * - `model_install_error` fired if the installation has an error
*/ */
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
export const addModelInstallEventListener = (startAppListening: AppStartListening) => { export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelInstallDownloadStarted,
effect: async (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloading';
}
return draft;
})
);
}
},
});
startAppListening({ startAppListening({
actionCreator: socketModelInstallStarted, actionCreator: socketModelInstallStarted,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch, getState }) => {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'running';
}
return draft;
})
);
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallDownloadProgress, actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch }) => { effect: async (action, { dispatch, getState }) => {
const { bytes, total_bytes, id } = action.payload.data; const { bytes, total_bytes, id } = action.payload.data;
const { data } = selectModelInstalls(getState());
dispatch( if (!data || !data.find((m) => m.id === id)) {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
const modelImport = draft.find((m) => m.id === id); } else {
if (modelImport) { dispatch(
modelImport.bytes = bytes; modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
modelImport.total_bytes = total_bytes; const modelImport = draft.find((m) => m.id === id);
modelImport.status = 'downloading'; if (modelImport) {
} modelImport.bytes = bytes;
return draft; modelImport.total_bytes = total_bytes;
}) modelImport.status = 'downloading';
); }
return draft;
})
);
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallComplete, actionCreator: socketModelInstallComplete,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data; const { id } = action.payload.data;
dispatch( const { data } = selectModelInstalls(getState());
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); if (!data || !data.find((m) => m.id === id)) {
if (modelImport) { dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
modelImport.status = 'completed'; } else {
} dispatch(
return draft; modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
}) const modelImport = draft.find((m) => m.id === id);
); if (modelImport) {
modelImport.status = 'completed';
}
return draft;
})
);
}
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
}, },
@ -72,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
startAppListening({ startAppListening({
actionCreator: socketModelInstallError, actionCreator: socketModelInstallError,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id, error, error_type } = action.payload.data; const { id, error, error_type } = action.payload.data;
const { data } = selectModelInstalls(getState());
dispatch( if (!data || !data.find((m) => m.id === id)) {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
const modelImport = draft.find((m) => m.id === id); } else {
if (modelImport) { dispatch(
modelImport.status = 'error'; modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
modelImport.error_reason = error_type; const modelImport = draft.find((m) => m.id === id);
modelImport.error = error; if (modelImport) {
} modelImport.status = 'error';
return draft; modelImport.error_reason = error_type;
}) modelImport.error = error;
); }
return draft;
})
);
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketModelInstallCancelled, actionCreator: socketModelInstallCancelled,
effect: (action, { dispatch }) => { effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data; const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
dispatch( if (!data || !data.find((m) => m.id === id)) {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
const modelImport = draft.find((m) => m.id === id); } else {
if (modelImport) { dispatch(
modelImport.status = 'cancelled'; modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
} const modelImport = draft.find((m) => m.id === id);
return draft; if (modelImport) {
}) modelImport.status = 'cancelled';
); }
return draft;
})
);
}
},
});
startAppListening({
actionCreator: socketModelInstallDownloadsComplete,
effect: (action, { dispatch, getState }) => {
const { id } = action.payload.data;
const { data } = selectModelInstalls(getState());
if (!data || !data.find((m) => m.id === id)) {
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
} else {
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'downloads_done';
}
return draft;
})
);
}
}, },
}); });
}; };