From 9832139e4e4722f173cae4c54a537daf36a70bea Mon Sep 17 00:00:00 2001
From: ccgauche <gaucheron.laurent@gmail.com>
Date: Tue, 16 Feb 2021 18:52:01 +0100
Subject: [PATCH] New cleaner safer version

---
 common/sys/src/plugin/errors.rs         |   1 +
 common/sys/src/plugin/memory_manager.rs |  66 +++++++++
 common/sys/src/plugin/mod.rs            |   2 +
 common/sys/src/plugin/module.rs         | 182 ++++--------------------
 common/sys/src/plugin/wasm_env.rs       |  48 +++++++
 common/sys/src/plugin/working.rs        |  28 ++++
 plugin/derive/src/lib.rs                |   2 +-
 plugin/rt/src/lib.rs                    |   8 +-
 8 files changed, 177 insertions(+), 160 deletions(-)
 create mode 100644 common/sys/src/plugin/memory_manager.rs
 create mode 100644 common/sys/src/plugin/wasm_env.rs
 create mode 100644 common/sys/src/plugin/working.rs

diff --git a/common/sys/src/plugin/errors.rs b/common/sys/src/plugin/errors.rs
index 21d33012f1..0529c5b5c0 100644
--- a/common/sys/src/plugin/errors.rs
+++ b/common/sys/src/plugin/errors.rs
@@ -25,6 +25,7 @@ pub enum PluginModuleError {
 
 #[derive(Debug)]
 pub enum MemoryAllocationError {
+    InvalidReturnType,
     AllocatorNotFound(ExportError),
     CantAllocate(RuntimeError),
 }
diff --git a/common/sys/src/plugin/memory_manager.rs b/common/sys/src/plugin/memory_manager.rs
new file mode 100644
index 0000000000..b242587a45
--- /dev/null
+++ b/common/sys/src/plugin/memory_manager.rs
@@ -0,0 +1,66 @@
+use std::sync::atomic::{AtomicI32, AtomicU32, Ordering};
+
+use serde::{Serialize, de::DeserializeOwned};
+use wasmer::{Function, Memory, Value};
+
+use super::errors::{MemoryAllocationError, PluginModuleError};
+
+
+pub struct MemoryManager {
+    pub pointer: AtomicI32,
+    pub length: AtomicU32
+}
+
+impl MemoryManager {
+
+    pub fn new() -> Self{
+        Self {
+            pointer: AtomicI32::new(0),
+            length: AtomicU32::new(0),
+        }
+    }
+
+    // This function check if the buffer is wide enough if not it realloc the buffer calling the `wasm_prepare_buffer` function
+    // Note: There is probably optimizations that can be done using less restrictive ordering
+    pub fn get_pointer(&self, object_length: u32, allocator: &Function) -> Result<i32,MemoryAllocationError> {
+        if self.length.load(Ordering::SeqCst) >= object_length {
+            return Ok(self.pointer.load(Ordering::SeqCst));
+        }
+        let pointer = allocator
+            .call(&[Value::I32(object_length as i32)])
+            .map_err(MemoryAllocationError::CantAllocate)?;
+        let pointer= pointer[0].i32().ok_or(MemoryAllocationError::InvalidReturnType)?;
+        self.length.store(object_length, Ordering::SeqCst);
+        self.pointer.store(pointer, Ordering::SeqCst);
+        Ok(pointer)
+    }
+    
+    // This function writes an object to WASM memory returning a pointer and a length. Will realloc the buffer is not wide enough
+    pub fn write_data<T: Serialize>(&self, memory: &Memory, allocator: &Function ,object: &T) -> Result<(i32,u32),PluginModuleError> {
+        self.write_bytes(memory,allocator,&bincode::serialize(object).map_err(PluginModuleError::Encoding)?)
+    }
+
+    // This function writes an raw bytes to WASM memory returning a pointer and a length. Will realloc the buffer is not wide enough
+    pub fn write_bytes(&self, memory: &Memory, allocator: &Function ,array: &[u8]) -> Result<(i32,u32),PluginModuleError> {
+        let len = array.len();
+        let mem_position = self.get_pointer(len as u32, allocator).map_err(PluginModuleError::MemoryAllocation)? as usize;
+        memory.view()[mem_position..mem_position + len]
+            .iter()
+            .zip(array.iter())
+            .for_each(|(cell, byte)| cell.set(*byte));
+        Ok((mem_position as i32, len as u32))
+    }
+}
+
+// This function read data from memory at a position with the array length and converts it to an object using bincode
+pub fn read_data<T: DeserializeOwned>(memory: &Memory, position: i32, length: u32) -> Result<T, bincode::Error> {
+    bincode::deserialize(&read_bytes(memory,position,length))
+}
+
+// This function read raw bytes from memory at a position with the array length
+pub fn read_bytes(memory: &Memory, position: i32, length: u32) -> Vec<u8> {
+    memory.view()[(position as usize)..(position as usize) + length as usize]
+        .iter()
+        .map(|x| x.get())
+        .collect::<Vec<_>>()
+}
\ No newline at end of file
diff --git a/common/sys/src/plugin/mod.rs b/common/sys/src/plugin/mod.rs
index 383c1192e2..c393dacf13 100644
--- a/common/sys/src/plugin/mod.rs
+++ b/common/sys/src/plugin/mod.rs
@@ -1,5 +1,7 @@
 pub mod errors;
 pub mod module;
+pub mod wasm_env;
+pub mod memory_manager;
 
 use common::assets::ASSETS_PATH;
 use serde::{Deserialize, Serialize};
diff --git a/common/sys/src/plugin/module.rs b/common/sys/src/plugin/module.rs
index 1a780911a2..ae5fb2483a 100644
--- a/common/sys/src/plugin/module.rs
+++ b/common/sys/src/plugin/module.rs
@@ -1,12 +1,12 @@
-use std::{cell::{Cell, RefCell}, cmp::Ordering, collections::HashSet, marker::PhantomData, rc::Rc, sync::{self, Arc, Mutex, atomic::AtomicI32}};
+use std::{collections::HashSet, marker::PhantomData, sync::{Arc, Mutex, atomic::AtomicI32}};
 
 use specs::World;
 use wasmer::{
-    imports, Cranelift, Function, HostEnvInitError, Instance, LazyInit, Memory, MemoryView, Module,
-    Store, Value, WasmerEnv, JIT,
+    imports, Cranelift, Function, Instance, Memory, Module,
+    Store, Value, JIT,
 };
 
-use super::errors::{MemoryAllocationError, PluginError, PluginModuleError};
+use super::{errors::{PluginError, PluginModuleError}, memory_manager::{self, MemoryManager}, wasm_env::HostFunctionEnvironement};
 
 use plugin_api::{Action, Event};
 
@@ -14,8 +14,11 @@ use plugin_api::{Action, Event};
 // This structure represent the WASM State of the plugin.
 pub struct PluginModule {
     ecs: Arc<AtomicI32>,
-    wasm_state: Arc<Mutex<WasmState>>,
+    wasm_state: Arc<Mutex<Instance>>,
+    memory_manager: Arc<MemoryManager>,
     events: HashSet<String>,
+    allocator: Function,
+    memory: Memory,
     name: String,
 }
 
@@ -30,21 +33,8 @@ impl PluginModule {
         let module = Module::new(&store, &wasm_data).expect("Can't compile");
 
         // This is the function imported into the wasm environement
-        fn raw_emit_actions(env: &EmitActionEnv, ptr: u32, len: u32) {
-            let memory: &Memory = if let Some(e) = env.memory.get_ref() {
-                e
-            } else {
-                // This should not be possible but I prefer be safer!
-                tracing::error!("Can't get memory from: `{}` plugin", env.name);
-                return;
-            };
-            let memory: MemoryView<u8> = memory.view();
-
-            let str_slice = &memory[ptr as usize..(ptr + len) as usize];
-
-            let bytes: Vec<u8> = str_slice.iter().map(|x| x.get()).collect();
-
-            handle_actions(match bincode::deserialize(&bytes) {
+        fn raw_emit_actions(env: &HostFunctionEnvironement, ptr: u32, len: u32) {
+            handle_actions(match env.read_data(ptr as i32, len) {
                 Ok(e) => e,
                 Err(e) => {
                     tracing::error!(?e, "Can't decode action");
@@ -52,42 +42,14 @@ impl PluginModule {
                 },
             });
         }
-        
-        fn raw_retreive_action(env: &EmitActionEnv, ptr: u32, len: u32) {
-            let memory: &Memory = if let Some(e) = env.memory.get_ref() {
-                e
-            } else {
-                // This should not be possible but I prefer be safer!
-                tracing::error!("Can't get memory from: `{}` plugin", env.name);
-                return;
-            };
-            let memory: MemoryView<u8> = memory.view();
-
-            let str_slice = &memory[ptr as usize..(ptr + len) as usize];
-
-            let bytes: Vec<u8> = str_slice.iter().map(|x| x.get()).collect();
-
-            let r = env.ecs.load(std::sync::atomic::Ordering::SeqCst);
-            if r == i32::MAX {
-                println!("No ECS availible 1");
-                return;
-            }
-            unsafe {
-                if let Some(t) = (r as *const World).as_ref() {
-                    println!("We have a pointer there");
-                } else {
-                    println!("No ECS availible 2");
-                }
-            }
-        }
 
         let ecs = Arc::new(AtomicI32::new(i32::MAX));
+        let memory_manager = Arc::new(MemoryManager::new());
 
         // Create an import object.
         let import_object = imports! {
             "env" => {
-                "raw_emit_actions" => Function::new_native_with_env(&store, EmitActionEnv::new(name.clone(), ecs.clone()), raw_emit_actions),
-                "raw_retreive_action" => Function::new_native_with_env(&store, EmitActionEnv::new(name.clone(), ecs.clone()), raw_retreive_action),
+                "raw_emit_actions" => Function::new_native_with_env(&store, HostFunctionEnvironement::new(name.clone(), ecs.clone(),memory_manager.clone()), raw_emit_actions),
             }
         };
 
@@ -95,13 +57,16 @@ impl PluginModule {
         let instance = Instance::new(&module, &import_object)
             .map_err(PluginModuleError::InstantiationError)?;
         Ok(Self {
+            memory_manager,
             ecs,
+            memory: instance.exports.get_memory("memory").map_err(PluginModuleError::MemoryUninit)?.clone(),
+            allocator: instance.exports.get_function("wasm_prepare_buffer").map_err(PluginModuleError::MemoryUninit)?.clone(),
             events: instance
                 .exports
                 .iter()
                 .map(|(name, _)| name.to_string())
                 .collect(),
-            wasm_state: Arc::new(Mutex::new(WasmState::new(instance))),
+            wasm_state: Arc::new(Mutex::new(instance)),
             name,
         })
     }
@@ -123,7 +88,7 @@ impl PluginModule {
         self.ecs.store((&ecs) as *const _ as i32, std::sync::atomic::Ordering::SeqCst);
         let bytes = {
             let mut state = self.wasm_state.lock().unwrap();
-            match execute_raw(&mut state, event_name, &request.bytes) {
+            match execute_raw(self,&mut state,event_name,&request.bytes) {
                 Ok(e) => e,
                 Err(e) => return Some(Err(e)),
             }
@@ -133,54 +98,6 @@ impl PluginModule {
     }
 }
 
-/// This is an internal struct used to represent the WASM state when the
-/// emit_action function is called
-#[derive(Clone)]
-struct EmitActionEnv {
-    ecs: Arc<AtomicI32>,
-    memory: LazyInit<Memory>,
-    name: String,
-}
-
-impl EmitActionEnv {
-    fn new(name: String,ecs: Arc<AtomicI32>) -> Self {
-        Self {
-            ecs,
-            memory: LazyInit::new(),
-            name,
-        }
-    }
-}
-
-impl WasmerEnv for EmitActionEnv {
-    fn init_with_instance(&mut self, instance: &Instance) -> Result<(), HostEnvInitError> {
-        let memory = instance.exports.get_memory("memory").unwrap();
-        self.memory.initialize(memory.clone());
-        Ok(())
-    }
-}
-
-pub struct WasmMemoryContext {
-    memory_buffer_size: usize,
-    memory_pointer: i32,
-}
-
-pub struct WasmState {
-    instance: Instance,
-    memory: WasmMemoryContext,
-}
-
-impl WasmState {
-    fn new(instance: Instance) -> Self {
-        Self {
-            instance,
-            memory: WasmMemoryContext {
-                memory_buffer_size: 0,
-                memory_pointer: 0,
-            },
-        }
-    }
-}
 
 // This structure represent a Pre-encoded event object (Useful to avoid
 // reencoding for each module in every plugin)
@@ -207,71 +124,30 @@ impl<T: Event> PreparedEventQuery<T> {
 // an interface to limit unsafe behaviours
 #[allow(clippy::needless_range_loop)]
 fn execute_raw(
-    instance: &mut WasmState,
+    module: &PluginModule,
+    instance: &mut Instance,
     event_name: &str,
     bytes: &[u8],
 ) -> Result<Vec<u8>, PluginModuleError> {
-    let len = bytes.len();
-
-    let mem_position = reserve_wasm_memory_buffer(len, &instance.instance, &mut instance.memory)
-        .map_err(PluginModuleError::MemoryAllocation)? as usize;
-
-    let memory = instance
-        .instance
-        .exports
-        .get_memory("memory")
-        .map_err(PluginModuleError::MemoryUninit)?;
-
-    memory.view()[mem_position..mem_position + len]
-        .iter()
-        .zip(bytes.iter())
-        .for_each(|(cell, byte)| cell.set(*byte));
+    let (mem_position,len) = module.memory_manager.write_bytes(&module.memory, &module.allocator, bytes)?;
 
     let func = instance
-        .instance
         .exports
         .get_function(event_name)
         .map_err(PluginModuleError::MemoryUninit)?;
 
-    let mem_position = func
+    let function_result = func
         .call(&[Value::I32(mem_position as i32), Value::I32(len as i32)])
-        .map_err(PluginModuleError::RunFunction)?[0]
+        .map_err(PluginModuleError::RunFunction)?;
+        
+    let pointer = function_result[0]
         .i32()
-        .ok_or_else(PluginModuleError::InvalidArgumentType)? as usize;
+        .ok_or_else(PluginModuleError::InvalidArgumentType)?;
+    let length = function_result[1]
+        .i32()
+        .ok_or_else(PluginModuleError::InvalidArgumentType)? as u32;
 
-    let view: MemoryView<u8> = memory.view();
-
-    let mut new_len_bytes = [0u8; 4];
-    // TODO: It is probably better to dirrectly make the new_len_bytes
-    for i in 0..4 {
-        new_len_bytes[i] = view.get(i + 1).map(Cell::get).unwrap_or(0);
-    }
-
-    let len = u32::from_ne_bytes(new_len_bytes) as usize;
-
-    Ok(view[mem_position..mem_position + len]
-        .iter()
-        .map(|x| x.get())
-        .collect())
-}
-
-fn reserve_wasm_memory_buffer(
-    size: usize,
-    instance: &Instance,
-    context: &mut WasmMemoryContext,
-) -> Result<i32, MemoryAllocationError> {
-    if context.memory_buffer_size >= size {
-        return Ok(context.memory_pointer);
-    }
-    let pointer = instance
-        .exports
-        .get_function("wasm_prepare_buffer")
-        .map_err(MemoryAllocationError::AllocatorNotFound)?
-        .call(&[Value::I32(size as i32)])
-        .map_err(MemoryAllocationError::CantAllocate)?;
-    context.memory_buffer_size = size;
-    context.memory_pointer = pointer[0].i32().unwrap();
-    Ok(context.memory_pointer)
+    Ok(memory_manager::read_bytes(&module.memory, pointer, length))
 }
 
 fn handle_actions(actions: Vec<Action>) {
diff --git a/common/sys/src/plugin/wasm_env.rs b/common/sys/src/plugin/wasm_env.rs
new file mode 100644
index 0000000000..fc332fc505
--- /dev/null
+++ b/common/sys/src/plugin/wasm_env.rs
@@ -0,0 +1,48 @@
+use std::sync::Arc;
+use std::sync::atomic::AtomicI32;
+
+use serde::{Serialize, de::DeserializeOwned};
+use wasmer::{Function, HostEnvInitError, Instance, LazyInit, Memory, WasmerEnv};
+
+use super::{errors::PluginModuleError, memory_manager::{self, MemoryManager}};
+
+#[derive(Clone)]
+pub struct HostFunctionEnvironement {
+    pub ecs: Arc<AtomicI32>, // This represent the pointer to the ECS object (set to i32::MAX if to ECS is availible)
+    pub memory: LazyInit<Memory>, // This object represent the WASM Memory
+    pub allocator: LazyInit<Function>, // Linked to: wasm_prepare_buffer
+    pub memory_manager: Arc<MemoryManager>, // This object represent the current buffer size and pointer 
+    pub name: String, // This represent the plugin name
+}
+
+impl HostFunctionEnvironement {
+    pub fn new(name: String,ecs: Arc<AtomicI32>,memory_manager: Arc<MemoryManager>) -> Self {
+        Self {
+            memory_manager,
+            ecs,
+            allocator: LazyInit::new(),
+            memory: LazyInit::new(),
+            name,
+        }
+    }
+
+    // This function is a safe interface to WASM memory that writes data to the memory returning a pointer and length
+    pub fn write_data<T: Serialize>(&self, object: &T) -> Result<(i32,u32),PluginModuleError> {
+        self.memory_manager.write_data(self.memory.get_ref().unwrap(), self.allocator.get_ref().unwrap(), object)
+    }
+
+    // This function is a safe interface to WASM memory that reads memory from pointer and length returning an object
+    pub fn read_data<T: DeserializeOwned>(&self, position: i32, length: u32) -> Result<T, bincode::Error> {
+        memory_manager::read_data(self.memory.get_ref().unwrap(), position, length)
+    }
+}
+
+impl WasmerEnv for HostFunctionEnvironement {
+    fn init_with_instance(&mut self, instance: &Instance) -> Result<(), HostEnvInitError> {
+        let memory = instance.exports.get_memory("memory").unwrap();
+        self.memory.initialize(memory.clone());
+        let allocator = instance.exports.get_function("wasm_prepare_buffer").expect("Can't get allocator");
+        self.allocator.initialize(allocator.clone());
+        Ok(())
+    }
+}
\ No newline at end of file
diff --git a/common/sys/src/plugin/working.rs b/common/sys/src/plugin/working.rs
new file mode 100644
index 0000000000..73f8817c3d
--- /dev/null
+++ b/common/sys/src/plugin/working.rs
@@ -0,0 +1,28 @@
+ 
+        fn raw_retreive_action(env: &EmitActionEnv, ptr: u32, len: u32) -> (u32, i32) {
+            let memory: &Memory = if let Some(e) = env.memory.get_ref() {
+                e
+            } else {
+                // This should not be possible but I prefer be safer!
+                tracing::error!("Can't get memory from: `{}` plugin", env.name);
+                return ();
+            };
+            let memory: MemoryView<u8> = memory.view();
+
+            let str_slice = &memory[ptr as usize..(ptr + len) as usize];
+
+            let bytes: Vec<u8> = str_slice.iter().map(|x| x.get()).collect();
+
+            let r = env.ecs.load(std::sync::atomic::Ordering::SeqCst);
+            if r == i32::MAX {
+                println!("No ECS availible 1");
+                return;
+            }
+            unsafe {
+                if let Some(t) = (r as *const World).as_ref() {
+                    println!("We have a pointer there");
+                } else {
+                    println!("No ECS availible 2");
+                }
+            }
+        }
\ No newline at end of file
diff --git a/plugin/derive/src/lib.rs b/plugin/derive/src/lib.rs
index 77ee375937..de4e96bd21 100644
--- a/plugin/derive/src/lib.rs
+++ b/plugin/derive/src/lib.rs
@@ -16,7 +16,7 @@ pub fn event_handler(_args: TokenStream, item: TokenStream) -> TokenStream {
     let out: proc_macro2::TokenStream = quote! {
         #[allow(clippy::unnecessary_wraps)]
         #[no_mangle]
-        pub fn #fn_name(intern__ptr: i32, intern__len: u32) -> i32 {
+        pub fn #fn_name(intern__ptr: i32, intern__len: u32) -> (i32,i32) {
             let input = ::veloren_plugin_rt::read_input(intern__ptr,intern__len).unwrap();
             #[inline]
             fn inner(#fn_args) #fn_return {
diff --git a/plugin/rt/src/lib.rs b/plugin/rt/src/lib.rs
index f9ac93dc58..3a85a9275f 100644
--- a/plugin/rt/src/lib.rs
+++ b/plugin/rt/src/lib.rs
@@ -51,13 +51,9 @@ where
     bincode::deserialize(slice).map_err(|_| "Failed to deserialize function input")
 }
 
-pub fn write_output(value: impl Serialize) -> i32 {
+pub fn write_output(value: impl Serialize) -> (i32,i32) {
     let ret = bincode::serialize(&value).expect("Can't serialize event output");
-    let len = ret.len() as u32;
-    unsafe {
-        ::std::ptr::write(1 as _, len);
-    }
-    ret.as_ptr() as _
+    (ret.as_ptr() as _, ret.len() as _)
 }
 
 static mut BUFFERS: Vec<u8> = Vec::new();