From 227788bc2472819ecd13f36abdcab3d6f3f9b107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 29 Jan 2021 11:03:28 +0100 Subject: [PATCH] Sync with syntaxdot main branch, adding biaffine parsing --- Cargo.lock | 21 +++++++++++---------- Cargo.toml | 11 ++++++----- src/annotator.rs | 31 +++++++++++++++++++++++++++++-- src/error.rs | 2 +- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 072147a..6c7fd2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -875,10 +875,10 @@ dependencies = [ [[package]] name = "syntaxdot" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3505bcde7418572353cd512fe20d7cd7591ebcf0a3c722dffcdd6fb5884ea592" +source = "git+https://github.com/tensordot/syntaxdot.git?branch=main#6ccbeaa1586a042c680d70f15d6ab415662b57e4" dependencies = [ "conllu", + "log", "ndarray", "numberer", "ordered-float", @@ -898,15 +898,18 @@ dependencies = [ [[package]] name = "syntaxdot-encoders" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce374616bf198ae18cb560f16004c9674aaad5c09e9feb4831b507957de04ef" +source = "git+https://github.com/tensordot/syntaxdot.git?branch=main#6ccbeaa1586a042c680d70f15d6ab415662b57e4" dependencies = [ "conllu", + "itertools 0.9.0", "lazy_static", + "ndarray", "numberer", "ohnomore", "ordered-float", "petgraph", + "rand", + "rand_xorshift", "seqalign", "serde", "serde_derive", @@ -925,6 +928,7 @@ dependencies = [ "prost-build", "serde_yaml", "syntaxdot", + "syntaxdot-encoders", "syntaxdot-tch-ext", "syntaxdot-tokenizers", "syntaxdot-transformers", @@ -936,8 +940,7 @@ dependencies = [ [[package]] name = "syntaxdot-tch-ext" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a70949fc02ebbc8edb2f8c2cef3d5a114ed9d9013559d823bfa468d650ecf6f" +source = "git+https://github.com/tensordot/syntaxdot.git?branch=main#6ccbeaa1586a042c680d70f15d6ab415662b57e4" dependencies = [ "itertools 0.9.0", "tch", @@ -946,8 +949,7 @@ dependencies = [ [[package]] name = "syntaxdot-tokenizers" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24e321044ca972b3cce4fd7624727da896c7425ae73acb6071ecd1da3cc965b9" +source = "git+https://github.com/tensordot/syntaxdot.git?branch=main#6ccbeaa1586a042c680d70f15d6ab415662b57e4" dependencies = [ "conllu", "ndarray", @@ -959,8 +961,7 @@ dependencies = [ [[package]] name = "syntaxdot-transformers" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c882901a13e3c064dc04c23e1868538817caa903b28abe97d954e2cabf3fb67" +source = "git+https://github.com/tensordot/syntaxdot.git?branch=main#6ccbeaa1586a042c680d70f15d6ab415662b57e4" dependencies = [ "serde", "syntaxdot-tch-ext", diff --git a/Cargo.toml b/Cargo.toml index 7469c0d..4de7857 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,11 @@ ffi-support = "0.4" lazy_static = "1" prost = "0.6" serde_yaml = "0.8" -syntaxdot = { version = "0.2", default-features = false } -syntaxdot-tch-ext = "0.2.0" -syntaxdot-tokenizers = "0.2.0" -syntaxdot-transformers = { version = "0.2", default-features = false } +syntaxdot = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" } +syntaxdot-encoders = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" } +syntaxdot-tch-ext = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" } +syntaxdot-tokenizers = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" } +syntaxdot-transformers = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" } tch = "0.3" thiserror = "1" toml = "0.5" @@ -28,4 +29,4 @@ prost-build = "0.6" pretty_assertions = "0.6" [features] -model-tests = [] +model-tests = [] \ No newline at end of file diff --git a/src/annotator.rs b/src/annotator.rs index 9734f2a..dfbb81c 100644 --- a/src/annotator.rs +++ b/src/annotator.rs @@ -4,10 +4,11 @@ use std::ops::Deref; use std::path::Path; use conllu::graph::Sentence; -use syntaxdot::config::{Config, PretrainConfig, TomlRead}; +use syntaxdot::config::{BiaffineParserConfig, Config, PretrainConfig, TomlRead}; use syntaxdot::encoders::Encoders; use syntaxdot::model::bert::BertModel; use syntaxdot::tagger::Tagger; +use syntaxdot_encoders::dependency::ImmutableDependencyEncoder; use syntaxdot_tch_ext::RootExt; use syntaxdot_tokenizers::{SentenceWithPieces, Tokenize}; use tch::nn::VarStore; @@ -58,6 +59,11 @@ impl Annotator { let mut config = Config::from_toml_read(r)?; config.relativize_paths(config_path)?; + let biaffine_decoder = config + .biaffine + .as_ref() + .map(|config| load_biaffine_decoder(config)) + .transpose()?; let encoders = load_encoders(&config)?; let tokenizer = load_tokenizer(&config)?; let pretrain_config = load_pretrain_config(&config)?; @@ -67,6 +73,11 @@ impl Annotator { let model = BertModel::new( vs.root_ext(|_| 0), &pretrain_config, + config.biaffine.as_ref(), + biaffine_decoder + .as_ref() + .map(ImmutableDependencyEncoder::n_relations) + .unwrap_or(0), &encoders, 0.0, config.model.position_embeddings.clone(), @@ -76,7 +87,7 @@ impl Annotator { vs.freeze(); - let tagger = Tagger::new(device, model, encoders); + let tagger = Tagger::new(device, model, biaffine_decoder, encoders); Ok(Annotator { tagger: TaggerWrap(tagger), @@ -111,6 +122,22 @@ pub fn load_pretrain_config(config: &Config) -> Result Result { + let f = File::open(&config.labels).map_err(|err| { + AnnotatorError::IO( + format!("Cannot open biaffine label file: {}", config.labels), + err, + ) + })?; + + let encoder: ImmutableDependencyEncoder = serde_yaml::from_reader(&f) + .map_err(|err| AnnotatorError::LoadEncoders(config.labels.clone(), err))?; + + Ok(encoder) +} + fn load_encoders(config: &Config) -> Result { let f = File::open(&config.labeler.labels).map_err(|err| { AnnotatorError::IO( diff --git a/src/error.rs b/src/error.rs index af14b2f..93e205d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use std::io; use ffi_support::{ErrorCode, ExternError}; use syntaxdot::error::SyntaxDotError; -use syntaxdot_transformers::error::TransformerError; +use syntaxdot_transformers::TransformerError; use thiserror::Error; pub mod error_codes {