mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): handle new model_install_download_started
event
When a model install is initiated from outside the client, we now trigger the model manager tab's model install list to update. - Handle new `model_install_download_started` event - Handle `model_install_download_complete` event (this event is not new but was never handled) - Update optimistic updates/cache invalidation logic to efficiently update the model install list
This commit is contained in:
parent
56771de856
commit
f002bca2fa
@ -5,6 +5,8 @@ import {
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallDownloadsComplete,
|
||||
socketModelInstallDownloadStarted,
|
||||
socketModelInstallError,
|
||||
socketModelInstallStarted,
|
||||
} from 'services/events/actions';
|
||||
@ -14,9 +16,12 @@ import {
|
||||
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
|
||||
* downloaded and is being "physically" installed.
|
||||
*
|
||||
* Note: the download events are only fired for remote model installs, not local.
|
||||
*
|
||||
* Here's the expected flow:
|
||||
* - Model manager does some prep
|
||||
* - `model_install_download_progress` fired when the download starts and continually until the download is complete
|
||||
* - API receives install request, model manager preps the install
|
||||
* - `model_install_download_started` fired when the download starts
|
||||
* - `model_install_download_progress` fired continually until the download is complete
|
||||
* - `model_install_download_complete` fired when the download is complete
|
||||
* - `model_install_started` fired when the "physical" installation starts
|
||||
* - `model_install_complete` fired when the installation is complete
|
||||
@ -24,47 +29,98 @@ import {
|
||||
* - `model_install_error` fired if the installation has an error
|
||||
*/
|
||||
|
||||
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
|
||||
|
||||
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadStarted,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallStarted,
|
||||
effect: async (action, { dispatch }) => {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'running';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadProgress,
|
||||
effect: async (action, { dispatch }) => {
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.bytes = bytes;
|
||||
modelImport.total_bytes = total_bytes;
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.bytes = bytes;
|
||||
modelImport.total_bytes = total_bytes;
|
||||
modelImport.status = 'downloading';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallComplete,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'completed';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'completed';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
|
||||
},
|
||||
@ -72,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallError,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id, error, error_type } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'error';
|
||||
modelImport.error_reason = error_type;
|
||||
modelImport.error = error;
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'error';
|
||||
modelImport.error_reason = error_type;
|
||||
modelImport.error = error;
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallCancelled,
|
||||
effect: (action, { dispatch }) => {
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'cancelled';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'cancelled';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadsComplete,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { id } = action.payload.data;
|
||||
const { data } = selectModelInstalls(getState());
|
||||
|
||||
if (!data || !data.find((m) => m.id === id)) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||
} else {
|
||||
dispatch(
|
||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||
const modelImport = draft.find((m) => m.id === id);
|
||||
if (modelImport) {
|
||||
modelImport.status = 'downloads_done';
|
||||
}
|
||||
return draft;
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user