diff --git a/src/main.rs b/src/main.rs index c807609..b732a4a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,7 +28,7 @@ use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc; use domain::net::server::stream::StreamServer; use tokio::net::{TcpListener, UdpSocket}; -use crate::service::middleware::{MetricsMiddlewareSvc, Stats, TsigMiddlewareSvc}; +use crate::service::middleware::{MetricsMiddlewareSvc, Rfc2136MiddlewareSvc, Stats}; use crate::service::Watcher; mod config; @@ -76,7 +76,7 @@ async fn main() { let dnsr = Arc::new(dnsr); let dnsr_svc = EdnsMiddlewareSvc::new(dnsr.clone()); let dnsr_svc = MandatoryMiddlewareSvc::new(dnsr_svc); - let dnsr_svc = TsigMiddlewareSvc::new(dnsr.clone(), dnsr_svc); + let dnsr_svc = Rfc2136MiddlewareSvc::new(dnsr.clone(), dnsr_svc); let dnsr_svc = MetricsMiddlewareSvc::new(dnsr_svc, stats.clone()); let addr = "0.0.0.0:53"; diff --git a/src/service/middleware/mod.rs b/src/service/middleware/mod.rs index 98edeef..d0482d8 100644 --- a/src/service/middleware/mod.rs +++ b/src/service/middleware/mod.rs @@ -1,5 +1,5 @@ mod metric; -mod tsig; +mod rfc2136; pub use metric::{MetricsMiddlewareSvc, Stats}; -pub use tsig::TsigMiddlewareSvc; +pub use rfc2136::Rfc2136MiddlewareSvc; diff --git a/src/service/middleware/tsig.rs b/src/service/middleware/rfc2136.rs similarity index 51% rename from src/service/middleware/tsig.rs rename to src/service/middleware/rfc2136.rs index 4e2f9b4..45e48bd 100644 --- a/src/service/middleware/tsig.rs +++ b/src/service/middleware/rfc2136.rs @@ -1,33 +1,38 @@ use core::future::{ready, Ready}; +use std::collections::HashMap; use std::marker::PhantomData; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use bytes::Bytes; use domain::base::iana::Rcode; use domain::base::message_builder::AdditionalBuilder; use domain::base::wire::Composer; -use domain::base::{Message, Name, Rtype, StreamTarget, ToName}; +use domain::base::{Message, Name, ParsedName, Rtype, StreamTarget, ToName, Ttl}; use domain::dep::octseq::Octets; use domain::net::server::message::Request; use domain::net::server::middleware::stream::{MiddlewareStream, PostprocessingStream}; use domain::net::server::service::{Service, ServiceResult}; use domain::net::server::util::mk_builder_for_target; use domain::rdata::tsig::Time48; +use domain::rdata::{AllRecordData, ZoneRecordData}; use domain::tsig::{Key, ServerSequence, ServerTransaction}; -use domain::zonetree::Answer; +use domain::zonetree::types::{StoredRecord, StoredRecordData}; +use domain::zonetree::{Answer, Rrset}; use futures::stream::Once; +use futures::FutureExt; use crate::key::{DomainName, KeyStore, Keys}; +use crate::service::handler::HandlerResult; #[derive(Clone, Debug)] -pub struct TsigMiddlewareSvc { +pub struct Rfc2136MiddlewareSvc { dnsr: Arc, svc: Svc, _octets: PhantomData, } -impl TsigMiddlewareSvc +impl Rfc2136MiddlewareSvc where RequestOctets: Octets + Send + Sync + Unpin + Clone, Svc: Service, @@ -42,18 +47,36 @@ where } fn postprocess_non_axfr( - keystore: &KeyStore, - keys: &Keys, + dnsr: Arc, qname: &Name, message: &mut Message>, response: &mut AdditionalBuilder>, ) -> Result<(), AdditionalBuilder>::Target>>> { - match ServerTransaction::request::>(keystore, message, Time48::now()) { + let keystore = dnsr.keystore.read().unwrap(); + let keys = &dnsr.config.keys; + let cloned_message = message.clone(); + let bytes = cloned_message.as_slice(); + let message_bytes = Message::from_octets(Bytes::copy_from_slice(bytes)).unwrap(); + + match ServerTransaction::request::>(&keystore, message, Time48::now()) { Ok(None) => Ok(()), Ok(Some(transaction)) if validate_key_scope(keys, transaction.key(), qname) => { log::info!(target: "svc", "found tsig key for transaction"); - transaction.answer(response, Time48::now()).unwrap(); - Ok(()) + + match handle_update_query(dnsr.clone(), message_bytes) { + Ok(_) => { + log::info!(target: "update", "successfully updated the zone"); + log::debug!("{:?}", dnsr.zones); + transaction.answer(response, Time48::now()).unwrap(); + Ok(()) + } + Err(e) => { + log::error!(target: "update", "error while updating the dnsr zones: {}", e); + let answer = Answer::new(Rcode::SERVFAIL); + let builder = mk_builder_for_target(); + Err(answer.to_message(message, builder)) + } + } } Ok(_) => { log::error!(target: "tsig", "tsig used is not in the valid scope"); @@ -71,18 +94,36 @@ where } fn postprocess_axfr( - keystore: &KeyStore, - keys: &Keys, + dnsr: Arc, qname: &Name, message: &mut Message>, response: &mut AdditionalBuilder>, ) -> Result<(), AdditionalBuilder>::Target>>> { - match ServerSequence::request::>(keystore, message, Time48::now()) { + let keystore = dnsr.keystore.read().unwrap(); + let keys = &dnsr.config.keys; + let cloned_message = message.clone(); + let bytes = cloned_message.as_slice(); + let message_bytes = Message::from_octets(Bytes::copy_from_slice(bytes)).unwrap(); + + match ServerSequence::request::>(&keystore, message, Time48::now()) { Ok(None) => Ok(()), Ok(Some(mut sequence)) if validate_key_scope(keys, sequence.key(), qname) => { log::info!(target: "svc", "found tsig key for transaction"); - sequence.answer(response, Time48::now()).unwrap(); - Ok(()) + + match handle_update_query(dnsr.clone(), message_bytes) { + Ok(_) => { + log::info!(target: "update", "successfully updated the zone"); + log::debug!("{:?}", dnsr.zones); + sequence.answer(response, Time48::now()).unwrap(); + Ok(()) + } + Err(e) => { + log::error!(target: "update", "error while updating the dnsr zones: {}", e); + let answer = Answer::new(Rcode::SERVFAIL); + let builder = mk_builder_for_target(); + Err(answer.to_message(message, builder)) + } + } } Ok(_) => { log::error!(target: "tsig", "tsig used is not in the valid scope"); @@ -104,8 +145,6 @@ where request: &Request, response: &mut AdditionalBuilder>, ) -> Result<(), AdditionalBuilder>::Target>>> { - let keystore = dnsr.keystore.read().unwrap(); - let keys = &dnsr.config.keys; let bytes = request.message().as_slice(); let mut message = Message::from_octets(bytes.to_vec()).unwrap(); let qname = request @@ -122,9 +161,9 @@ where .map(|q| q.qtype() == Rtype::AXFR), Ok(true) ) { - Self::postprocess_non_axfr(&keystore, keys, &qname, &mut message, response) + Self::postprocess_non_axfr(dnsr, &qname, &mut message, response) } else { - Self::postprocess_axfr(&keystore, keys, &qname, &mut message, response) + Self::postprocess_axfr(dnsr, &qname, &mut message, response) } } @@ -144,7 +183,7 @@ where } } -impl Service for TsigMiddlewareSvc +impl Service for Rfc2136MiddlewareSvc where RequestOctets: Octets + Send + Sync + 'static + Unpin + Clone, Svc: Service, @@ -175,11 +214,84 @@ where } fn validate_key_scope(keys: &Keys, key: &Key, dname: &Name) -> bool { - let key_file = dbg!(key.name().into()); + let key_file = key.name().into(); let dname = Into::::into(dname).strip_prefix(); - dbg!(keys - .get(&key_file) + keys.get(&key_file) .map(|d| d.contains_key(&dname)) - .unwrap_or(false)) + .unwrap_or(false) +} + +fn handle_update_query( + dnsr: Arc, + message: Message, +) -> HandlerResult<()> { + log::debug!("handle_update_query"); + let authority = message.authority()?; + let mut records: HashMap<(Rtype, Ttl), Vec> = HashMap::new(); + + for a in authority { + let a = a?.to_record::>>()?; + + if let Some(record) = a { + let data: ZoneRecordData> = match record.data() { + AllRecordData::Txt(txt) => txt.clone().into(), + _ => unimplemented!(), + }; + + let record = StoredRecord::new( + record.owner().to_bytes(), + record.class(), + record.ttl(), + data, + ); + records + .entry((record.rtype(), record.ttl())) + .or_default() + .push(record.data().to_owned()); + } + } + + let question = message.sole_question().unwrap(); + let qtype = question.qtype(); + let qname = question.qname().clone(); + let records = Arc::new(Mutex::new(records)); + let cloned_records = records.clone(); + + let op = Box::new(move |owner: Name, rrset: &Rrset| { + if rrset.rtype() == qtype && owner == qname { + let mut records = cloned_records.lock().unwrap(); + records + .entry((rrset.rtype(), rrset.ttl())) + .or_default() + .extend(rrset.data().to_vec()); + } + }); + + dnsr.zones.find_zone_walk(question.qname(), |zone| { + if let Some(zone) = zone { + zone.walk(op); + } + }); + + let mutex = Arc::try_unwrap(records).unwrap(); + let records = mutex.into_inner().unwrap(); + + // TODO: handle this lot of unwraps + if let Some(zone) = dnsr.zones.find_zone(&question.qname()) { + let mut writer = zone.write().now_or_never().unwrap(); + let open = writer.open().now_or_never().unwrap().unwrap(); + + records.into_iter().for_each(|((rtype, ttl), data)| { + let mut rset = Rrset::new(rtype, ttl); + data.into_iter().for_each(|data| rset.push_data(data)); + open.update_rrset(rset.into_shared()) + .now_or_never() + .unwrap() + .unwrap(); + }); + writer.commit().now_or_never().unwrap().unwrap(); + } + + Ok(()) } diff --git a/src/service/mod.rs b/src/service/mod.rs index e034554..dbf4805 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,35 +1,26 @@ use core::future::{ready, Future}; -use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; use std::sync::Mutex; use std::sync::RwLock; -use bytes::Bytes; use domain::base::iana::Opcode; use domain::base::iana::{Class, Rcode}; use domain::base::message_builder::AdditionalBuilder; use domain::base::Message; use domain::base::Name; -use domain::base::ParsedName; -use domain::base::Ttl; use domain::base::{Rtype, ToName}; use domain::dep::octseq::OctetsBuilder; use domain::net::server::message::Request; use domain::net::server::service::CallResult; use domain::net::server::service::{Service, ServiceResult}; use domain::net::server::util::mk_builder_for_target; -use domain::rdata::AllRecordData; -use domain::rdata::ZoneRecordData; -use domain::zonetree::types::StoredRecord; -use domain::zonetree::types::StoredRecordData; use domain::zonetree::Rrset; use domain::zonetree::{Answer, ReadableZone, Zone}; use futures::channel::mpsc::unbounded; use futures::channel::mpsc::UnboundedSender; use futures::stream::{once, Stream}; -use futures::FutureExt; use crate::config::Config; use crate::error::Error; @@ -86,11 +77,6 @@ impl Service> for Dnsr { impl HandleDNS for Dnsr { fn handle_non_axfr(&self, request: Request>) -> HandlerResult>> { - let bytes = request.message().as_slice(); - let message = Message::from_octets(Bytes::copy_from_slice(bytes)).unwrap(); - - handle_update_query(self, message)?; - let answer = { let question = request.message().sole_question().unwrap(); self.zones @@ -227,76 +213,6 @@ impl HandleDNS for Dnsr { } } -fn handle_update_query(dnsr: &Dnsr, message: Message) -> HandlerResult<()> { - let authority = message.authority()?; - let mut records: HashMap<(Rtype, Ttl), Vec> = HashMap::new(); - - for a in authority { - let a = a?.to_record::>>()?; - - if let Some(record) = a { - let data: ZoneRecordData> = match record.data() { - AllRecordData::Txt(txt) => txt.clone().into(), - _ => unimplemented!(), - }; - - let record = StoredRecord::new( - record.owner().to_bytes(), - record.class(), - record.ttl(), - data, - ); - records - .entry((record.rtype(), record.ttl())) - .or_default() - .push(record.data().to_owned()); - } - } - - let question = message.sole_question().unwrap(); - let qtype = question.qtype(); - let qname = question.qname().clone(); - let records = Arc::new(Mutex::new(records)); - let cloned_records = records.clone(); - - let op = Box::new(move |owner: Name, rrset: &Rrset| { - if rrset.rtype() == qtype && owner == qname { - let mut records = cloned_records.lock().unwrap(); - records - .entry((rrset.rtype(), rrset.ttl())) - .or_default() - .extend(rrset.data().to_vec()); - } - }); - - dnsr.zones.find_zone_walk(question.qname(), |zone| { - if let Some(zone) = zone { - zone.walk(op); - } - }); - - let mutex = Arc::try_unwrap(records).unwrap(); - let records = mutex.into_inner().unwrap(); - - // TODO: handle this lot of unwraps - if let Some(zone) = dnsr.zones.find_zone(&question.qname()) { - let mut writer = zone.write().now_or_never().unwrap(); - let open = writer.open().now_or_never().unwrap().unwrap(); - - records.into_iter().for_each(|((rtype, ttl), data)| { - let mut rset = Rrset::new(rtype, ttl); - data.into_iter().for_each(|data| rset.push_data(data)); - open.update_rrset(rset.into_shared()) - .now_or_never() - .unwrap() - .unwrap(); - }); - writer.commit().now_or_never().unwrap().unwrap(); - } - - Ok(()) -} - fn add_to_stream( answer: Answer, msg: &Message>,