Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error hanlding #565

Merged
merged 4 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions executor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod proof;
mod task;

fn setup_console(level: Option<String>) {
std::panic::set_hook(Box::new(console_error_panic_hook::hook));
console_error_panic_hook::set_once();
let level = level.map(|x| x.to_uppercase()).unwrap_or("INFO".into());
let _ = console_log::init_with_level(log::Level::from_str(level.as_str()).unwrap());
}
Expand All @@ -35,34 +35,41 @@ extern "C" {
#[wasm_bindgen(typescript_type = "JsCallback")]
pub type JsCallback;

#[wasm_bindgen(structural, method, js_name = "getStorage")]
pub async fn get_storage(this: &JsCallback, key: JsValue) -> JsValue;
#[wasm_bindgen(catch, structural, method, js_name = "getStorage")]
pub async fn get_storage(this: &JsCallback, key: JsValue) -> Result<JsValue, JsValue>;

#[wasm_bindgen(structural, method, js_name = "getStateRoot")]
pub async fn get_state_root(this: &JsCallback) -> JsValue;
#[wasm_bindgen(catch, structural, method, js_name = "getStateRoot")]
pub async fn get_state_root(this: &JsCallback) -> Result<JsValue, JsValue>;

#[wasm_bindgen(structural, method, js_name = "getNextKey")]
pub async fn get_next_key(this: &JsCallback, prefix: JsValue, key: JsValue) -> JsValue;
#[wasm_bindgen(catch, structural, method, js_name = "getNextKey")]
pub async fn get_next_key(
this: &JsCallback,
prefix: JsValue,
key: JsValue,
) -> Result<JsValue, JsValue>;

#[wasm_bindgen(structural, method, js_name = "offchainGetStorage")]
pub async fn offchain_get_storage(this: &JsCallback, key: JsValue) -> JsValue;
#[wasm_bindgen(catch, structural, method, js_name = "offchainGetStorage")]
pub async fn offchain_get_storage(this: &JsCallback, key: JsValue) -> Result<JsValue, JsValue>;

#[wasm_bindgen(structural, method, js_name = "offchainTimestamp")]
pub async fn offchain_timestamp(this: &JsCallback) -> JsValue;
#[wasm_bindgen(catch, structural, method, js_name = "offchainTimestamp")]
pub async fn offchain_timestamp(this: &JsCallback) -> Result<JsValue, JsValue>;

#[wasm_bindgen(structural, method, js_name = "offchainRandomSeed")]
pub async fn offchain_random_seed(this: &JsCallback) -> JsValue;
#[wasm_bindgen(catch, structural, method, js_name = "offchainRandomSeed")]
pub async fn offchain_random_seed(this: &JsCallback) -> Result<JsValue, JsValue>;

#[wasm_bindgen(structural, method, js_name = "offchainSubmitTransaction")]
pub async fn offchain_submit_transaction(this: &JsCallback, tx: JsValue) -> JsValue;
#[wasm_bindgen(catch, structural, method, js_name = "offchainSubmitTransaction")]
pub async fn offchain_submit_transaction(
this: &JsCallback,
tx: JsValue,
) -> Result<JsValue, JsValue>;
}

#[wasm_bindgen]
pub async fn get_runtime_version(code: JsValue) -> Result<JsValue, JsValue> {
pub async fn get_runtime_version(code: JsValue) -> Result<JsValue, JsError> {
setup_console(None);

let code = serde_wasm_bindgen::from_value::<HexString>(code)?;
let runtime_version = task::runtime_version(code).await?;
let runtime_version = task::runtime_version(code).await;
let result = serde_wasm_bindgen::to_value(&runtime_version)?;

Ok(result)
Expand All @@ -72,13 +79,13 @@ pub async fn get_runtime_version(code: JsValue) -> Result<JsValue, JsValue> {
pub async fn calculate_state_root(
entries: JsValue,
trie_version: JsValue,
) -> Result<JsValue, JsValue> {
) -> Result<JsValue, JsError> {
setup_console(None);

let entries = serde_wasm_bindgen::from_value::<Vec<(HexString, HexString)>>(entries)?;
let trie_version = serde_wasm_bindgen::from_value::<u8>(trie_version)?;
let trie_version =
TrieEntryVersion::try_from(trie_version).map_err(|_| "invalid trie version")?;
let trie_version = TrieEntryVersion::try_from(trie_version)
.map_err(|_| JsError::new("invalid trie version"))?;
let hash = task::calculate_state_root(entries, trie_version);
let result = serde_wasm_bindgen::to_value(&hash)?;

Expand All @@ -90,7 +97,7 @@ pub async fn decode_proof(
root_trie_hash: JsValue,
keys: JsValue,
nodes: JsValue,
) -> Result<JsValue, JsValue> {
) -> Result<JsValue, JsError> {
setup_console(None);

let root_trie_hash = serde_wasm_bindgen::from_value::<HashHexString>(root_trie_hash)?;
Expand All @@ -100,14 +107,15 @@ pub async fn decode_proof(
root_trie_hash,
keys,
nodes.into_iter().map(|x| x.0).collect(),
)?;
)
.map_err(|e| JsError::new(e.as_str()))?;
let result = serde_wasm_bindgen::to_value(&entries)?;

Ok(result)
}

#[wasm_bindgen]
pub async fn create_proof(nodes: JsValue, entries: JsValue) -> Result<JsValue, JsValue> {
pub async fn create_proof(nodes: JsValue, entries: JsValue) -> Result<JsValue, JsError> {
setup_console(None);

let proof = serde_wasm_bindgen::from_value::<Vec<HexString>>(nodes)?;
Expand All @@ -117,7 +125,8 @@ pub async fn create_proof(nodes: JsValue, entries: JsValue) -> Result<JsValue, J
.into_iter()
.map(|(key, value)| (key.0, value.map(|x| x.0))),
);
let proof = proof::create_proof(proof.into_iter().map(|x| x.0).collect(), entries)?;
let proof = proof::create_proof(proof.into_iter().map(|x| x.0).collect(), entries)
.map_err(|e| JsError::new(e.as_str()))?;
let result = serde_wasm_bindgen::to_value(&proof)?;

Ok(result)
Expand All @@ -137,3 +146,10 @@ pub async fn run_task(

Ok(result)
}

#[wasm_bindgen]
pub async fn testing(js: JsCallback, key: JsValue) -> Result<JsValue, JsValue> {
setup_console(None);

js.get_storage(key).await
}
49 changes: 24 additions & 25 deletions executor/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use smoldot::{
verify::body_only::LogEmitInfo,
};
use std::collections::BTreeMap;
use wasm_bindgen::prelude::*;

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -108,18 +109,18 @@ fn prefixed_child_key(child: impl Iterator<Item = u8>, key: impl Iterator<Item =
.concat()
}

fn handle_value(value: wasm_bindgen::JsValue) -> Result<Option<Vec<u8>>, String> {
fn handle_value(value: wasm_bindgen::JsValue) -> Result<Option<Vec<u8>>, JsError> {
if value.is_string() {
let encoded = from_value::<HexString>(value)
.map(|x| x.0)
.map_err(|e| e.to_string())?;
?;
Ok(Some(encoded))
} else {
Ok(None)
}
}

pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskResponse, String> {
pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskResponse, JsValue> {
let mut storage_main_trie_changes = TrieDiff::default();
let mut storage_changes: BTreeMap<Vec<u8>, Option<Vec<u8>>> = Default::default();
let mut offchain_storage_changes: BTreeMap<Vec<u8>, Option<Vec<u8>>> = Default::default();
Expand Down Expand Up @@ -172,13 +173,12 @@ pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskRespo
)
} else {
// otherwise, ask chopsticks
let key = to_value(&key).map_err(|e| e.to_string())?;
let key = to_value(&key)?;

let value = js.get_storage(key).await;
let value = js.get_storage(key).await?;
let value = if value.is_string() {
let encoded = from_value::<HexString>(value)
.map(|x| x.0)
.map_err(|e| e.to_string())?;
.map(|x| x.0)?;
Some(encoded)
} else {
None
Expand All @@ -188,10 +188,9 @@ pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskRespo
}

RuntimeHostVm::ClosestDescendantMerkleValue(req) => {
let value = js.get_state_root().await;
let value = js.get_state_root().await?;
let value = from_value::<HexString>(value)
.map(|x| x.0)
.map_err(|e| e.to_string())?;
.map(|x| x.0)?;
req.inject_merkle_value(Some(value.as_ref()))
}

Expand All @@ -218,9 +217,9 @@ pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskRespo
} else {
HexString(nibbles_to_bytes_suffix_extend(req.key()).collect::<Vec<_>>())
};
let prefix = to_value(&prefix).map_err(|e| e.to_string())?;
let key = to_value(&key).map_err(|e| e.to_string())?;
let value = js.get_next_key(prefix, key).await;
let prefix = to_value(&prefix)?;
let key = to_value(&key)?;
let value = js.get_next_key(prefix, key).await?;
req.inject_key(
handle_value(value)?.map(|x| bytes_to_nibbles(x.into_iter())),
)
Expand Down Expand Up @@ -248,8 +247,8 @@ pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskRespo
RuntimeHostVm::Offchain(ctx) => match ctx {
OffchainContext::StorageGet(req) => {
let key = HexString(req.key().as_ref().to_vec());
let key = to_value(&key).map_err(|e| e.to_string())?;
let value = js.offchain_get_storage(key).await;
let key = to_value(&key)?;
let value = js.offchain_get_storage(key).await?;
req.inject_value(handle_value(value)?)
}

Expand All @@ -273,26 +272,26 @@ pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskRespo
}

OffchainContext::Timestamp(req) => {
let value = js.offchain_timestamp().await;
let timestamp = from_value::<u64>(value).map_err(|e| e.to_string())?;
let value = js.offchain_timestamp().await?;
let timestamp = from_value::<u64>(value)?;
req.inject_timestamp(timestamp)
}

OffchainContext::RandomSeed(req) => {
let value = js.offchain_random_seed().await;
let random = from_value::<HexString>(value).map_err(|e| e.to_string())?;
let value = js.offchain_random_seed().await?;
let random = from_value::<HexString>(value)?;
let value: [u8; 32] = random
.0
.try_into()
.map_err(|_| "invalid random seed value")?;
.map_err(|_| JsError::new("invalid random seed value"))?;
req.inject_random_seed(value)
}

OffchainContext::SubmitTransaction(req) => {
let tx = HexString(req.transaction().as_ref().to_vec());
let tx = to_value(&tx).map_err(|e| e.to_string())?;
let success = js.offchain_submit_transaction(tx).await;
let success = from_value::<bool>(success).map_err(|e| e.to_string())?;
let tx = to_value(&tx)?;
let success = js.offchain_submit_transaction(tx).await?;
let success = from_value::<bool>(success)?;
req.resume(success)
}
},
Expand Down Expand Up @@ -402,7 +401,7 @@ pub async fn run_task(task: TaskCall, js: crate::JsCallback) -> Result<TaskRespo
}))
}

pub async fn runtime_version(wasm: HexString) -> Result<RuntimeVersion, String> {
pub async fn runtime_version(wasm: HexString) -> RuntimeVersion {
let vm_proto = HostVmPrototype::new(Config {
module: &wasm,
heap_pages: HeapPages::from(2048),
Expand All @@ -413,7 +412,7 @@ pub async fn runtime_version(wasm: HexString) -> Result<RuntimeVersion, String>

let core_version = vm_proto.runtime_version().decode();

Ok(RuntimeVersion::new(core_version))
RuntimeVersion::new(core_version)
}

pub fn calculate_state_root(
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/wasm-executor/browser-wasm-executor.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ const runTask = async (task, callback) => {
return pkg.run_task(task, callback, 'info')
}

const wasmExecutor = { runTask, getRuntimeVersion, calculateStateRoot, createProof, decodeProof }
const testing = async (callback, key) => {
return pkg.testing(callback, key)
}

const wasmExecutor = { runTask, getRuntimeVersion, calculateStateRoot, createProof, decodeProof, testing }

Comlink.expose(wasmExecutor)
36 changes: 33 additions & 3 deletions packages/core/src/wasm-executor/executor.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import * as Comlink from 'comlink'
import { HexString } from '@polkadot/util/types'
import { TypeRegistry } from '@polkadot/types'
import { describe, expect, it } from 'vitest'
import { readFileSync } from 'node:fs'
import _ from 'lodash'
import path from 'node:path'

import {
Expand All @@ -11,13 +13,21 @@ import {
hrmpIngressChannelIndex,
upgradeGoAheadSignal,
} from '../utils/proof.js'
import { calculateStateRoot, createProof, decodeProof, getAuraSlotDuration, getRuntimeVersion } from './index.js'
import {
calculateStateRoot,
createProof,
decodeProof,
emptyTaskHandler,
getAuraSlotDuration,
getRuntimeVersion,
getWorker,
} from './index.js'

const getCode = () => {
const getCode = _.memoize(() => {
const code = String(readFileSync(path.join(__dirname, '../../../e2e/blobs/acala-runtime-2101.txt'))).trim()
expect(code.length).toBeGreaterThan(2)
return code as HexString
}
})

describe('wasm', () => {
it('get runtime version from wasm runtime', async () => {
Expand Down Expand Up @@ -147,4 +157,24 @@ describe('wasm', () => {
const slotDuration = await getAuraSlotDuration(getCode())
expect(slotDuration).eq(12000)
})

it('handles panic', async () => {
const worker = await getWorker()

await expect(() =>
worker.remote.testing(
Comlink.proxy({
...emptyTaskHandler,
getStorage: () => {
throw new Error('panic')
},
}),
'0x0000',
),
).rejects.toThrowError('panic')

// ensure the worker is still good
const slotDuration = await getAuraSlotDuration(getCode())
expect(slotDuration).eq(12000)
})
})
3 changes: 2 additions & 1 deletion packages/core/src/wasm-executor/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ export interface WasmExecutor {
},
callback?: JsCallback,
) => Promise<TaskResponse>
testing: (callback: JsCallback, key: any) => Promise<any>
}

const logger = defaultLogger.child({ name: 'executor' })

let __executor_worker: Promise<{ remote: Comlink.Remote<WasmExecutor>; terminate: () => Promise<void> }> | undefined
const getWorker = async () => {
export const getWorker = async () => {
if (__executor_worker) return __executor_worker

const isNode = typeof process !== 'undefined' && process?.versions?.node // true for node or bun
Expand Down
6 changes: 5 additions & 1 deletion packages/core/src/wasm-executor/node-wasm-executor.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ const runTask = async (task, callback) => {
return pkg.run_task(task, callback, process.env.RUST_LOG)
}

const wasmExecutor = { runTask, getRuntimeVersion, calculateStateRoot, createProof, decodeProof }
const testing = async (callback, key) => {
return pkg.testing(callback, key)
}

const wasmExecutor = { runTask, getRuntimeVersion, calculateStateRoot, createProof, decodeProof, testing }

Comlink.expose(wasmExecutor, nodeEndpoint(parentPort))