mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): handle new model_install_download_started
event
When a model install is initiated from outside the client, we now trigger the model manager tab's model install list to update. - Handle new `model_install_download_started` event - Handle `model_install_download_complete` event (this event is not new but was never handled) - Update optimistic updates/cache invalidation logic to efficiently update the model install list
This commit is contained in:
parent
56771de856
commit
f002bca2fa
@ -5,6 +5,8 @@ import {
|
|||||||
socketModelInstallCancelled,
|
socketModelInstallCancelled,
|
||||||
socketModelInstallComplete,
|
socketModelInstallComplete,
|
||||||
socketModelInstallDownloadProgress,
|
socketModelInstallDownloadProgress,
|
||||||
|
socketModelInstallDownloadsComplete,
|
||||||
|
socketModelInstallDownloadStarted,
|
||||||
socketModelInstallError,
|
socketModelInstallError,
|
||||||
socketModelInstallStarted,
|
socketModelInstallStarted,
|
||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
@ -14,9 +16,12 @@ import {
|
|||||||
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
|
* which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
|
||||||
* downloaded and is being "physically" installed.
|
* downloaded and is being "physically" installed.
|
||||||
*
|
*
|
||||||
|
* Note: the download events are only fired for remote model installs, not local.
|
||||||
|
*
|
||||||
* Here's the expected flow:
|
* Here's the expected flow:
|
||||||
* - Model manager does some prep
|
* - API receives install request, model manager preps the install
|
||||||
* - `model_install_download_progress` fired when the download starts and continually until the download is complete
|
* - `model_install_download_started` fired when the download starts
|
||||||
|
* - `model_install_download_progress` fired continually until the download is complete
|
||||||
* - `model_install_download_complete` fired when the download is complete
|
* - `model_install_download_complete` fired when the download is complete
|
||||||
* - `model_install_started` fired when the "physical" installation starts
|
* - `model_install_started` fired when the "physical" installation starts
|
||||||
* - `model_install_complete` fired when the installation is complete
|
* - `model_install_complete` fired when the installation is complete
|
||||||
@ -24,47 +29,98 @@ import {
|
|||||||
* - `model_install_error` fired if the installation has an error
|
* - `model_install_error` fired if the installation has an error
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
|
||||||
|
|
||||||
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketModelInstallDownloadStarted,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const { id } = action.payload.data;
|
||||||
|
const { data } = selectModelInstalls(getState());
|
||||||
|
|
||||||
|
if (!data || !data.find((m) => m.id === id)) {
|
||||||
|
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||||
|
} else {
|
||||||
|
dispatch(
|
||||||
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
|
if (modelImport) {
|
||||||
|
modelImport.status = 'downloading';
|
||||||
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketModelInstallStarted,
|
actionCreator: socketModelInstallStarted,
|
||||||
effect: async (action, { dispatch }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
const { id } = action.payload.data;
|
||||||
|
const { data } = selectModelInstalls(getState());
|
||||||
|
|
||||||
|
if (!data || !data.find((m) => m.id === id)) {
|
||||||
|
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||||
|
} else {
|
||||||
|
dispatch(
|
||||||
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
|
if (modelImport) {
|
||||||
|
modelImport.status = 'running';
|
||||||
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketModelInstallDownloadProgress,
|
actionCreator: socketModelInstallDownloadProgress,
|
||||||
effect: async (action, { dispatch }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const { bytes, total_bytes, id } = action.payload.data;
|
const { bytes, total_bytes, id } = action.payload.data;
|
||||||
|
const { data } = selectModelInstalls(getState());
|
||||||
|
|
||||||
dispatch(
|
if (!data || !data.find((m) => m.id === id)) {
|
||||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||||
const modelImport = draft.find((m) => m.id === id);
|
} else {
|
||||||
if (modelImport) {
|
dispatch(
|
||||||
modelImport.bytes = bytes;
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
modelImport.total_bytes = total_bytes;
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
modelImport.status = 'downloading';
|
if (modelImport) {
|
||||||
}
|
modelImport.bytes = bytes;
|
||||||
return draft;
|
modelImport.total_bytes = total_bytes;
|
||||||
})
|
modelImport.status = 'downloading';
|
||||||
);
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketModelInstallComplete,
|
actionCreator: socketModelInstallComplete,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const { id } = action.payload.data;
|
const { id } = action.payload.data;
|
||||||
|
|
||||||
dispatch(
|
const { data } = selectModelInstalls(getState());
|
||||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
|
||||||
const modelImport = draft.find((m) => m.id === id);
|
if (!data || !data.find((m) => m.id === id)) {
|
||||||
if (modelImport) {
|
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||||
modelImport.status = 'completed';
|
} else {
|
||||||
}
|
dispatch(
|
||||||
return draft;
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
})
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
);
|
if (modelImport) {
|
||||||
|
modelImport.status = 'completed';
|
||||||
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
|
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
|
||||||
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
|
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
|
||||||
},
|
},
|
||||||
@ -72,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
|||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketModelInstallError,
|
actionCreator: socketModelInstallError,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const { id, error, error_type } = action.payload.data;
|
const { id, error, error_type } = action.payload.data;
|
||||||
|
const { data } = selectModelInstalls(getState());
|
||||||
|
|
||||||
dispatch(
|
if (!data || !data.find((m) => m.id === id)) {
|
||||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||||
const modelImport = draft.find((m) => m.id === id);
|
} else {
|
||||||
if (modelImport) {
|
dispatch(
|
||||||
modelImport.status = 'error';
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
modelImport.error_reason = error_type;
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
modelImport.error = error;
|
if (modelImport) {
|
||||||
}
|
modelImport.status = 'error';
|
||||||
return draft;
|
modelImport.error_reason = error_type;
|
||||||
})
|
modelImport.error = error;
|
||||||
);
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketModelInstallCancelled,
|
actionCreator: socketModelInstallCancelled,
|
||||||
effect: (action, { dispatch }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const { id } = action.payload.data;
|
const { id } = action.payload.data;
|
||||||
|
const { data } = selectModelInstalls(getState());
|
||||||
|
|
||||||
dispatch(
|
if (!data || !data.find((m) => m.id === id)) {
|
||||||
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||||
const modelImport = draft.find((m) => m.id === id);
|
} else {
|
||||||
if (modelImport) {
|
dispatch(
|
||||||
modelImport.status = 'cancelled';
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
}
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
return draft;
|
if (modelImport) {
|
||||||
})
|
modelImport.status = 'cancelled';
|
||||||
);
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketModelInstallDownloadsComplete,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { id } = action.payload.data;
|
||||||
|
const { data } = selectModelInstalls(getState());
|
||||||
|
|
||||||
|
if (!data || !data.find((m) => m.id === id)) {
|
||||||
|
dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
|
||||||
|
} else {
|
||||||
|
dispatch(
|
||||||
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
|
if (modelImport) {
|
||||||
|
modelImport.status = 'downloads_done';
|
||||||
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user