feat(ui): add listeners for model load events

- currently only exposed as DEBUG-level logs
This commit is contained in:
psychedelicious 2023-07-16 02:26:30 +10:00
parent 7b6159f8d6
commit c487166d9c
7 changed files with 150 additions and 2 deletions

View File

@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
export const listenerMiddleware = createListenerMiddleware();
@ -177,6 +179,8 @@ addSocketConnectedListener();
addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
addModelLoadStartedEventListener();
addModelLoadCompletedEventListener();
// Session Created
addSessionCreatedPendingListener();

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadCompleted,
socketModelLoadCompleted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadCompletedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadCompleted(action.payload));
},
});
};

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadStarted,
socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadStartedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load started (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadStarted(action.payload));
},
});
};

View File

@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models
export type ModelType = components['schemas']['ModelType'];
export type SubModelType = components['schemas']['SubModelType'];
export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField'];

View File

@ -5,6 +5,8 @@ import {
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
} from 'services/events/types';
// Common socket action payload data
@ -162,3 +164,35 @@ export const socketGeneratorProgress = createAction<
export const appSocketGeneratorProgress = createAction<
BaseSocketPayload & { data: GeneratorProgressEvent }
>('socket/appSocketGeneratorProgress');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/socketModelLoadStarted');
/**
* App-level Model Load Started
*/
export const appSocketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/appSocketModelLoadStarted');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/socketModelLoadCompleted');
/**
* App-level Model Load Completed
*/
export const appSocketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/appSocketModelLoadCompleted');

View File

@ -1,5 +1,11 @@
import { O } from 'ts-toolbelt';
import { Graph, GraphExecutionState } from '../api/types';
import {
BaseModelType,
Graph,
GraphExecutionState,
ModelType,
SubModelType,
} from '../api/types';
/**
* A progress image, we get one for each step in the generation
@ -25,6 +31,25 @@ export type BaseNode = {
[key: string]: AnyInvocation[keyof AnyInvocation];
};
export type ModelLoadStartedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
};
export type ModelLoadCompletedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
hash?: string;
location: string;
precision: string;
};
/**
* A `generator_progress` socket.io event.
*
@ -101,6 +126,8 @@ export type ServerToClientEvents = {
graph_execution_state_complete: (
payload: GraphExecutionStateCompleteEvent
) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
};
export type ClientToServerEvents = {

View File

@ -11,6 +11,8 @@ import {
socketConnected,
socketDisconnected,
socketSubscribed,
socketModelLoadStarted,
socketModelLoadCompleted,
} from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types';
import { Logger } from 'roarr';
@ -44,7 +46,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
socketSubscribed({
sessionId,
timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId,
boardId: getState().gallery.selectedBoardId,
})
);
}
@ -118,4 +120,28 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
})
);
});
/**
* Model load started
*/
socket.on('model_load_started', (data) => {
dispatch(
socketModelLoadStarted({
data,
timestamp: getTimestamp(),
})
);
});
/**
* Model load completed
*/
socket.on('model_load_completed', (data) => {
dispatch(
socketModelLoadCompleted({
data,
timestamp: getTimestamp(),
})
);
});
};