feat: model events (#3786)

[feat(nodes): emit model loading
events](7b6159f8d6)

- remove dependency on having access to a `node` during emits, would
need a bit of additional args passed through the system and I don't
think its necessary at this point. this also allowed us to drop an
extraneous fetching/parsing of the session from db.
- provide the invocation context to all `get_model()` calls, so the
events are able to be emitted
- test all model loading events in the app and confirm socket events are
received

[feat(ui): add listeners for model load
events](c487166d9c)

- currently only exposed as DEBUG-level logs

---

One change I missed in the commit messages is the `ModelInfo` class is
not serializable, so I split out the pieces of information we didn't
already have (hash, location, precision) and added them to the event
payload directly.
This commit is contained in:
blessedcoolant 2023-07-18 12:57:47 +12:00 committed by GitHub
commit 0edb31febd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 174 additions and 39 deletions

View File

@ -57,10 +57,10 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput: def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(), **self.clip.tokenizer.dict(), context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(), context=context,
) )
def _lora_loader(): def _lora_loader():
@ -82,6 +82,7 @@ class CompelInvocation(BaseInvocation):
model_name=name, model_name=name,
base_model=self.clip.text_encoder.base_model, base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:

View File

@ -157,13 +157,13 @@ class InpaintInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context,)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
with vae_info as vae,\ with vae_info as vae,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\

View File

@ -76,7 +76,7 @@ def get_scheduler(
scheduler_name, SCHEDULER_MAP['ddim'] scheduler_name, SCHEDULER_MAP['ddim']
) )
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict() **scheduler_info.dict(), context=context,
) )
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@ -262,6 +262,7 @@ class TextToLatentsInvocation(BaseInvocation):
model_name=control_info.control_model.model_name, model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet, model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model, base_model=control_info.control_model.base_model,
context=context,
) )
) )
@ -313,14 +314,14 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -403,14 +404,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -491,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
with vae_info as vae: with vae_info as vae:
@ -636,7 +637,7 @@ class ImageToLatentsInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

View File

@ -105,8 +105,6 @@ class EventServiceBase:
def emit_model_load_started ( def emit_model_load_started (
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -117,8 +115,6 @@ class EventServiceBase:
event_name="model_load_started", event_name="model_load_started",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -129,8 +125,6 @@ class EventServiceBase:
def emit_model_load_completed( def emit_model_load_completed(
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -142,12 +136,12 @@ class EventServiceBase:
event_name="model_load_completed", event_name="model_load_completed",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info, hash=model_info.hash,
location=model_info.location,
precision=str(model_info.precision),
), ),
) )

View File

@ -339,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> ModelInfo: ) -> ModelInfo:
""" """
@ -347,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
part (such as the vae) of a diffusers mode. part (such as the vae) of a diffusers mode.
""" """
# if we are called from within a node, then we get to emit # we can emit model loading events if we are executing with access to the invocation context
# load start and complete events if context:
if node and context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -366,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
submodel, submodel,
) )
if node and context: if context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -510,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event( def _emit_load_event(
self, self,
node,
context, context,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: SubModelType, submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None, model_info: Optional[ModelInfo] = None,
): ):
if context.services.queue.is_canceled(context.graph_execution_state_id): if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException() raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if model_info: if model_info:
context.services.events.emit_model_load_completed( context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -536,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
else: else:
context.services.events.emit_model_load_started( context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,

View File

@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -177,6 +179,8 @@ addSocketConnectedListener();
addSocketDisconnectedListener(); addSocketDisconnectedListener();
addSocketSubscribedListener(); addSocketSubscribedListener();
addSocketUnsubscribedListener(); addSocketUnsubscribedListener();
addModelLoadStartedEventListener();
addModelLoadCompletedEventListener();
// Session Created // Session Created
addSessionCreatedPendingListener(); addSessionCreatedPendingListener();

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadCompleted,
socketModelLoadCompleted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadCompletedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadCompleted(action.payload));
},
});
};

View File

@ -0,0 +1,28 @@
import { log } from 'app/logging/useLogger';
import {
appSocketModelLoadStarted,
socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addModelLoadStartedEventListener = () => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action, { dispatch, getState }) => {
const { model_name, model_type, submodel } = action.payload.data;
let modelString = `${model_type} model: ${model_name}`;
if (submodel) {
modelString = modelString.concat(`, submodel: ${submodel}`);
}
moduleLog.debug(action.payload, `Model load started (${modelString})`);
// pass along the socket event as an application action
dispatch(appSocketModelLoadStarted(action.payload));
},
});
};

View File

@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
/** /**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime. * Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/ */
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime'); export const getTimestamp = () =>
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);

View File

@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models // Models
export type ModelType = components['schemas']['ModelType']; export type ModelType = components['schemas']['ModelType'];
export type SubModelType = components['schemas']['SubModelType'];
export type BaseModelType = components['schemas']['BaseModelType']; export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField']; export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField']; export type VAEModelField = components['schemas']['VAEModelField'];

View File

@ -5,6 +5,8 @@ import {
InvocationCompleteEvent, InvocationCompleteEvent,
InvocationErrorEvent, InvocationErrorEvent,
InvocationStartedEvent, InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
} from 'services/events/types'; } from 'services/events/types';
// Common socket action payload data // Common socket action payload data
@ -162,3 +164,35 @@ export const socketGeneratorProgress = createAction<
export const appSocketGeneratorProgress = createAction< export const appSocketGeneratorProgress = createAction<
BaseSocketPayload & { data: GeneratorProgressEvent } BaseSocketPayload & { data: GeneratorProgressEvent }
>('socket/appSocketGeneratorProgress'); >('socket/appSocketGeneratorProgress');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/socketModelLoadStarted');
/**
* App-level Model Load Started
*/
export const appSocketModelLoadStarted = createAction<
BaseSocketPayload & { data: ModelLoadStartedEvent }
>('socket/appSocketModelLoadStarted');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/socketModelLoadCompleted');
/**
* App-level Model Load Completed
*/
export const appSocketModelLoadCompleted = createAction<
BaseSocketPayload & { data: ModelLoadCompletedEvent }
>('socket/appSocketModelLoadCompleted');

View File

@ -1,5 +1,11 @@
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
import { Graph, GraphExecutionState } from '../api/types'; import {
BaseModelType,
Graph,
GraphExecutionState,
ModelType,
SubModelType,
} from '../api/types';
/** /**
* A progress image, we get one for each step in the generation * A progress image, we get one for each step in the generation
@ -25,6 +31,25 @@ export type BaseNode = {
[key: string]: AnyInvocation[keyof AnyInvocation]; [key: string]: AnyInvocation[keyof AnyInvocation];
}; };
export type ModelLoadStartedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
};
export type ModelLoadCompletedEvent = {
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
model_type: ModelType;
submodel: SubModelType;
hash?: string;
location: string;
precision: string;
};
/** /**
* A `generator_progress` socket.io event. * A `generator_progress` socket.io event.
* *
@ -101,6 +126,8 @@ export type ServerToClientEvents = {
graph_execution_state_complete: ( graph_execution_state_complete: (
payload: GraphExecutionStateCompleteEvent payload: GraphExecutionStateCompleteEvent
) => void; ) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
}; };
export type ClientToServerEvents = { export type ClientToServerEvents = {

View File

@ -11,6 +11,8 @@ import {
socketConnected, socketConnected,
socketDisconnected, socketDisconnected,
socketSubscribed, socketSubscribed,
socketModelLoadStarted,
socketModelLoadCompleted,
} from '../actions'; } from '../actions';
import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { ClientToServerEvents, ServerToClientEvents } from '../types';
import { Logger } from 'roarr'; import { Logger } from 'roarr';
@ -44,7 +46,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
socketSubscribed({ socketSubscribed({
sessionId, sessionId,
timestamp: getTimestamp(), timestamp: getTimestamp(),
boardId: getState().boards.selectedBoardId, boardId: getState().gallery.selectedBoardId,
}) })
); );
} }
@ -118,4 +120,28 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
}) })
); );
}); });
/**
* Model load started
*/
socket.on('model_load_started', (data) => {
dispatch(
socketModelLoadStarted({
data,
timestamp: getTimestamp(),
})
);
});
/**
* Model load completed
*/
socket.on('model_load_completed', (data) => {
dispatch(
socketModelLoadCompleted({
data,
timestamp: getTimestamp(),
})
);
});
}; };