Skip to content

Commit

Permalink
Sync with syntaxdot main branch, adding biaffine parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk authored and Daniël de Kok committed Feb 2, 2021
1 parent 04b46bc commit 227788b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 18 deletions.
21 changes: 11 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ ffi-support = "0.4"
lazy_static = "1"
prost = "0.6"
serde_yaml = "0.8"
syntaxdot = { version = "0.2", default-features = false }
syntaxdot-tch-ext = "0.2.0"
syntaxdot-tokenizers = "0.2.0"
syntaxdot-transformers = { version = "0.2", default-features = false }
syntaxdot = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" }
syntaxdot-encoders = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" }
syntaxdot-tch-ext = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" }
syntaxdot-tokenizers = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" }
syntaxdot-transformers = { git = "https://github.com/tensordot/syntaxdot.git", branch = "main" }
tch = "0.3"
thiserror = "1"
toml = "0.5"
Expand All @@ -28,4 +29,4 @@ prost-build = "0.6"
pretty_assertions = "0.6"

[features]
model-tests = []
model-tests = []
31 changes: 29 additions & 2 deletions src/annotator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::ops::Deref;
use std::path::Path;

use conllu::graph::Sentence;
use syntaxdot::config::{Config, PretrainConfig, TomlRead};
use syntaxdot::config::{BiaffineParserConfig, Config, PretrainConfig, TomlRead};
use syntaxdot::encoders::Encoders;
use syntaxdot::model::bert::BertModel;
use syntaxdot::tagger::Tagger;
use syntaxdot_encoders::dependency::ImmutableDependencyEncoder;
use syntaxdot_tch_ext::RootExt;
use syntaxdot_tokenizers::{SentenceWithPieces, Tokenize};
use tch::nn::VarStore;
Expand Down Expand Up @@ -58,6 +59,11 @@ impl Annotator {
let mut config = Config::from_toml_read(r)?;
config.relativize_paths(config_path)?;

let biaffine_decoder = config
.biaffine
.as_ref()
.map(|config| load_biaffine_decoder(config))
.transpose()?;
let encoders = load_encoders(&config)?;
let tokenizer = load_tokenizer(&config)?;
let pretrain_config = load_pretrain_config(&config)?;
Expand All @@ -67,6 +73,11 @@ impl Annotator {
let model = BertModel::new(
vs.root_ext(|_| 0),
&pretrain_config,
config.biaffine.as_ref(),
biaffine_decoder
.as_ref()
.map(ImmutableDependencyEncoder::n_relations)
.unwrap_or(0),
&encoders,
0.0,
config.model.position_embeddings.clone(),
Expand All @@ -76,7 +87,7 @@ impl Annotator {

vs.freeze();

let tagger = Tagger::new(device, model, encoders);
let tagger = Tagger::new(device, model, biaffine_decoder, encoders);

Ok(Annotator {
tagger: TaggerWrap(tagger),
Expand Down Expand Up @@ -111,6 +122,22 @@ pub fn load_pretrain_config(config: &Config) -> Result<PretrainConfig, Annotator
Ok(config.model.pretrain_config()?)
}

/// Load the dependency label encoder used by the biaffine parser.
///
/// Opens the file named by `config.labels` and deserializes its contents
/// from YAML into an `ImmutableDependencyEncoder`.
///
/// # Errors
///
/// * `AnnotatorError::IO` when the label file cannot be opened.
/// * `AnnotatorError::LoadEncoders` when the file contents cannot be
///   deserialized.
fn load_biaffine_decoder(
    config: &BiaffineParserConfig,
) -> Result<ImmutableDependencyEncoder, AnnotatorError> {
    let labels_path = &config.labels;

    // Report the offending path alongside the underlying I/O error.
    let labels_file = match File::open(labels_path) {
        Ok(file) => file,
        Err(io_err) => {
            return Err(AnnotatorError::IO(
                format!("Cannot open biaffine label file: {}", labels_path),
                io_err,
            ));
        }
    };

    // The target type is inferred from the function's return type.
    serde_yaml::from_reader(&labels_file)
        .map_err(|yaml_err| AnnotatorError::LoadEncoders(labels_path.clone(), yaml_err))
}

fn load_encoders(config: &Config) -> Result<Encoders, AnnotatorError> {
let f = File::open(&config.labeler.labels).map_err(|err| {
AnnotatorError::IO(
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io;

use ffi_support::{ErrorCode, ExternError};
use syntaxdot::error::SyntaxDotError;
use syntaxdot_transformers::error::TransformerError;
use syntaxdot_transformers::TransformerError;
use thiserror::Error;

pub mod error_codes {
Expand Down

0 comments on commit 227788b

Please sign in to comment.