Skip to content

Commit

Permalink
Merge pull request #34 from thibault-cne/issue/32
Browse files Browse the repository at this point in the history
Fixed security issue for dynamic update
  • Loading branch information
thibault-cne authored Aug 2, 2024
2 parents 6143d8c + 275bac3 commit 35f6fac
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 112 deletions.
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc;
use domain::net::server::stream::StreamServer;
use tokio::net::{TcpListener, UdpSocket};

use crate::service::middleware::{MetricsMiddlewareSvc, Stats, TsigMiddlewareSvc};
use crate::service::middleware::{MetricsMiddlewareSvc, Rfc2136MiddlewareSvc, Stats};
use crate::service::Watcher;

mod config;
Expand Down Expand Up @@ -76,7 +76,7 @@ async fn main() {
let dnsr = Arc::new(dnsr);
let dnsr_svc = EdnsMiddlewareSvc::new(dnsr.clone());
let dnsr_svc = MandatoryMiddlewareSvc::new(dnsr_svc);
let dnsr_svc = TsigMiddlewareSvc::new(dnsr.clone(), dnsr_svc);
let dnsr_svc = Rfc2136MiddlewareSvc::new(dnsr.clone(), dnsr_svc);
let dnsr_svc = MetricsMiddlewareSvc::new(dnsr_svc, stats.clone());

let addr = "0.0.0.0:53";
Expand Down
4 changes: 2 additions & 2 deletions src/service/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod metric;
mod tsig;
mod rfc2136;

pub use metric::{MetricsMiddlewareSvc, Stats};
pub use tsig::TsigMiddlewareSvc;
pub use rfc2136::Rfc2136MiddlewareSvc;
160 changes: 136 additions & 24 deletions src/service/middleware/tsig.rs → src/service/middleware/rfc2136.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,38 @@
use core::future::{ready, Ready};

use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use bytes::Bytes;
use domain::base::iana::Rcode;
use domain::base::message_builder::AdditionalBuilder;
use domain::base::wire::Composer;
use domain::base::{Message, Name, Rtype, StreamTarget, ToName};
use domain::base::{Message, Name, ParsedName, Rtype, StreamTarget, ToName, Ttl};
use domain::dep::octseq::Octets;
use domain::net::server::message::Request;
use domain::net::server::middleware::stream::{MiddlewareStream, PostprocessingStream};
use domain::net::server::service::{Service, ServiceResult};
use domain::net::server::util::mk_builder_for_target;
use domain::rdata::tsig::Time48;
use domain::rdata::{AllRecordData, ZoneRecordData};
use domain::tsig::{Key, ServerSequence, ServerTransaction};
use domain::zonetree::Answer;
use domain::zonetree::types::{StoredRecord, StoredRecordData};
use domain::zonetree::{Answer, Rrset};
use futures::stream::Once;
use futures::FutureExt;

use crate::key::{DomainName, KeyStore, Keys};
use crate::service::handler::HandlerResult;

#[derive(Clone, Debug)]
pub struct TsigMiddlewareSvc<Octets, Svc> {
pub struct Rfc2136MiddlewareSvc<Octets, Svc> {
dnsr: Arc<crate::service::Dnsr>,
svc: Svc,
_octets: PhantomData<Octets>,
}

impl<RequestOctets, Svc> TsigMiddlewareSvc<RequestOctets, Svc>
impl<RequestOctets, Svc> Rfc2136MiddlewareSvc<RequestOctets, Svc>
where
RequestOctets: Octets + Send + Sync + Unpin + Clone,
Svc: Service<RequestOctets>,
Expand All @@ -42,18 +47,36 @@ where
}

fn postprocess_non_axfr(
keystore: &KeyStore,
keys: &Keys,
dnsr: Arc<crate::service::Dnsr>,
qname: &Name<Bytes>,
message: &mut Message<Vec<u8>>,
response: &mut AdditionalBuilder<StreamTarget<Svc::Target>>,
) -> Result<(), AdditionalBuilder<StreamTarget<<Svc as Service<RequestOctets>>::Target>>> {
match ServerTransaction::request::<KeyStore, Vec<u8>>(keystore, message, Time48::now()) {
let keystore = dnsr.keystore.read().unwrap();
let keys = &dnsr.config.keys;
let cloned_message = message.clone();
let bytes = cloned_message.as_slice();
let message_bytes = Message::from_octets(Bytes::copy_from_slice(bytes)).unwrap();

match ServerTransaction::request::<KeyStore, Vec<u8>>(&keystore, message, Time48::now()) {
Ok(None) => Ok(()),
Ok(Some(transaction)) if validate_key_scope(keys, transaction.key(), qname) => {
log::info!(target: "svc", "found tsig key for transaction");
transaction.answer(response, Time48::now()).unwrap();
Ok(())

match handle_update_query(dnsr.clone(), message_bytes) {
Ok(_) => {
log::info!(target: "update", "successfully updated the zone");
log::debug!("{:?}", dnsr.zones);
transaction.answer(response, Time48::now()).unwrap();
Ok(())
}
Err(e) => {
log::error!(target: "update", "error while updating the dnsr zones: {}", e);
let answer = Answer::new(Rcode::SERVFAIL);
let builder = mk_builder_for_target();
Err(answer.to_message(message, builder))
}
}
}
Ok(_) => {
log::error!(target: "tsig", "tsig used is not in the valid scope");
Expand All @@ -71,18 +94,36 @@ where
}

fn postprocess_axfr(
keystore: &KeyStore,
keys: &Keys,
dnsr: Arc<crate::service::Dnsr>,
qname: &Name<Bytes>,
message: &mut Message<Vec<u8>>,
response: &mut AdditionalBuilder<StreamTarget<Svc::Target>>,
) -> Result<(), AdditionalBuilder<StreamTarget<<Svc as Service<RequestOctets>>::Target>>> {
match ServerSequence::request::<KeyStore, Vec<u8>>(keystore, message, Time48::now()) {
let keystore = dnsr.keystore.read().unwrap();
let keys = &dnsr.config.keys;
let cloned_message = message.clone();
let bytes = cloned_message.as_slice();
let message_bytes = Message::from_octets(Bytes::copy_from_slice(bytes)).unwrap();

match ServerSequence::request::<KeyStore, Vec<u8>>(&keystore, message, Time48::now()) {
Ok(None) => Ok(()),
Ok(Some(mut sequence)) if validate_key_scope(keys, sequence.key(), qname) => {
log::info!(target: "svc", "found tsig key for transaction");
sequence.answer(response, Time48::now()).unwrap();
Ok(())

match handle_update_query(dnsr.clone(), message_bytes) {
Ok(_) => {
log::info!(target: "update", "successfully updated the zone");
log::debug!("{:?}", dnsr.zones);
sequence.answer(response, Time48::now()).unwrap();
Ok(())
}
Err(e) => {
log::error!(target: "update", "error while updating the dnsr zones: {}", e);
let answer = Answer::new(Rcode::SERVFAIL);
let builder = mk_builder_for_target();
Err(answer.to_message(message, builder))
}
}
}
Ok(_) => {
log::error!(target: "tsig", "tsig used is not in the valid scope");
Expand All @@ -104,8 +145,6 @@ where
request: &Request<RequestOctets>,
response: &mut AdditionalBuilder<StreamTarget<Svc::Target>>,
) -> Result<(), AdditionalBuilder<StreamTarget<<Svc as Service<RequestOctets>>::Target>>> {
let keystore = dnsr.keystore.read().unwrap();
let keys = &dnsr.config.keys;
let bytes = request.message().as_slice();
let mut message = Message::from_octets(bytes.to_vec()).unwrap();
let qname = request
Expand All @@ -122,9 +161,9 @@ where
.map(|q| q.qtype() == Rtype::AXFR),
Ok(true)
) {
Self::postprocess_non_axfr(&keystore, keys, &qname, &mut message, response)
Self::postprocess_non_axfr(dnsr, &qname, &mut message, response)
} else {
Self::postprocess_axfr(&keystore, keys, &qname, &mut message, response)
Self::postprocess_axfr(dnsr, &qname, &mut message, response)
}
}

Expand All @@ -144,7 +183,7 @@ where
}
}

impl<RequestOctets, Svc> Service<RequestOctets> for TsigMiddlewareSvc<RequestOctets, Svc>
impl<RequestOctets, Svc> Service<RequestOctets> for Rfc2136MiddlewareSvc<RequestOctets, Svc>
where
RequestOctets: Octets + Send + Sync + 'static + Unpin + Clone,
Svc: Service<RequestOctets>,
Expand Down Expand Up @@ -175,11 +214,84 @@ where
}

fn validate_key_scope(keys: &Keys, key: &Key, dname: &Name<Bytes>) -> bool {
let key_file = dbg!(key.name().into());
let key_file = key.name().into();
let dname = Into::<DomainName>::into(dname).strip_prefix();

dbg!(keys
.get(&key_file)
keys.get(&key_file)
.map(|d| d.contains_key(&dname))
.unwrap_or(false))
.unwrap_or(false)
}

fn handle_update_query(
dnsr: Arc<crate::service::Dnsr>,
message: Message<Bytes>,
) -> HandlerResult<()> {
log::debug!("handle_update_query");
let authority = message.authority()?;
let mut records: HashMap<(Rtype, Ttl), Vec<StoredRecordData>> = HashMap::new();

for a in authority {
let a = a?.to_record::<AllRecordData<Bytes, ParsedName<Bytes>>>()?;

if let Some(record) = a {
let data: ZoneRecordData<Bytes, Name<Bytes>> = match record.data() {
AllRecordData::Txt(txt) => txt.clone().into(),
_ => unimplemented!(),
};

let record = StoredRecord::new(
record.owner().to_bytes(),
record.class(),
record.ttl(),
data,
);
records
.entry((record.rtype(), record.ttl()))
.or_default()
.push(record.data().to_owned());
}
}

let question = message.sole_question().unwrap();
let qtype = question.qtype();
let qname = question.qname().clone();
let records = Arc::new(Mutex::new(records));
let cloned_records = records.clone();

let op = Box::new(move |owner: Name<Bytes>, rrset: &Rrset| {
if rrset.rtype() == qtype && owner == qname {
let mut records = cloned_records.lock().unwrap();
records
.entry((rrset.rtype(), rrset.ttl()))
.or_default()
.extend(rrset.data().to_vec());
}
});

dnsr.zones.find_zone_walk(question.qname(), |zone| {
if let Some(zone) = zone {
zone.walk(op);
}
});

let mutex = Arc::try_unwrap(records).unwrap();
let records = mutex.into_inner().unwrap();

// TODO: handle this lot of unwraps
if let Some(zone) = dnsr.zones.find_zone(&question.qname()) {
let mut writer = zone.write().now_or_never().unwrap();
let open = writer.open().now_or_never().unwrap().unwrap();

records.into_iter().for_each(|((rtype, ttl), data)| {
let mut rset = Rrset::new(rtype, ttl);
data.into_iter().for_each(|data| rset.push_data(data));
open.update_rrset(rset.into_shared())
.now_or_never()
.unwrap()
.unwrap();
});
writer.commit().now_or_never().unwrap().unwrap();
}

Ok(())
}
84 changes: 0 additions & 84 deletions src/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,26 @@
use core::future::{ready, Future};

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::RwLock;

use bytes::Bytes;
use domain::base::iana::Opcode;
use domain::base::iana::{Class, Rcode};
use domain::base::message_builder::AdditionalBuilder;
use domain::base::Message;
use domain::base::Name;
use domain::base::ParsedName;
use domain::base::Ttl;
use domain::base::{Rtype, ToName};
use domain::dep::octseq::OctetsBuilder;
use domain::net::server::message::Request;
use domain::net::server::service::CallResult;
use domain::net::server::service::{Service, ServiceResult};
use domain::net::server::util::mk_builder_for_target;
use domain::rdata::AllRecordData;
use domain::rdata::ZoneRecordData;
use domain::zonetree::types::StoredRecord;
use domain::zonetree::types::StoredRecordData;
use domain::zonetree::Rrset;
use domain::zonetree::{Answer, ReadableZone, Zone};
use futures::channel::mpsc::unbounded;
use futures::channel::mpsc::UnboundedSender;
use futures::stream::{once, Stream};
use futures::FutureExt;

use crate::config::Config;
use crate::error::Error;
Expand Down Expand Up @@ -86,11 +77,6 @@ impl Service<Vec<u8>> for Dnsr {

impl HandleDNS for Dnsr {
fn handle_non_axfr(&self, request: Request<Vec<u8>>) -> HandlerResult<CallResult<Vec<u8>>> {
let bytes = request.message().as_slice();
let message = Message::from_octets(Bytes::copy_from_slice(bytes)).unwrap();

handle_update_query(self, message)?;

let answer = {
let question = request.message().sole_question().unwrap();
self.zones
Expand Down Expand Up @@ -227,76 +213,6 @@ impl HandleDNS for Dnsr {
}
}

fn handle_update_query(dnsr: &Dnsr, message: Message<Bytes>) -> HandlerResult<()> {
let authority = message.authority()?;
let mut records: HashMap<(Rtype, Ttl), Vec<StoredRecordData>> = HashMap::new();

for a in authority {
let a = a?.to_record::<AllRecordData<Bytes, ParsedName<Bytes>>>()?;

if let Some(record) = a {
let data: ZoneRecordData<Bytes, Name<Bytes>> = match record.data() {
AllRecordData::Txt(txt) => txt.clone().into(),
_ => unimplemented!(),
};

let record = StoredRecord::new(
record.owner().to_bytes(),
record.class(),
record.ttl(),
data,
);
records
.entry((record.rtype(), record.ttl()))
.or_default()
.push(record.data().to_owned());
}
}

let question = message.sole_question().unwrap();
let qtype = question.qtype();
let qname = question.qname().clone();
let records = Arc::new(Mutex::new(records));
let cloned_records = records.clone();

let op = Box::new(move |owner: Name<Bytes>, rrset: &Rrset| {
if rrset.rtype() == qtype && owner == qname {
let mut records = cloned_records.lock().unwrap();
records
.entry((rrset.rtype(), rrset.ttl()))
.or_default()
.extend(rrset.data().to_vec());
}
});

dnsr.zones.find_zone_walk(question.qname(), |zone| {
if let Some(zone) = zone {
zone.walk(op);
}
});

let mutex = Arc::try_unwrap(records).unwrap();
let records = mutex.into_inner().unwrap();

// TODO: handle this lot of unwraps
if let Some(zone) = dnsr.zones.find_zone(&question.qname()) {
let mut writer = zone.write().now_or_never().unwrap();
let open = writer.open().now_or_never().unwrap().unwrap();

records.into_iter().for_each(|((rtype, ttl), data)| {
let mut rset = Rrset::new(rtype, ttl);
data.into_iter().for_each(|data| rset.push_data(data));
open.update_rrset(rset.into_shared())
.now_or_never()
.unwrap()
.unwrap();
});
writer.commit().now_or_never().unwrap().unwrap();
}

Ok(())
}

fn add_to_stream(
answer: Answer,
msg: &Message<Vec<u8>>,
Expand Down

0 comments on commit 35f6fac

Please sign in to comment.