Skip to content

Commit

Permalink
Merge pull request #24 from thibault-cne/issue/23
Browse files Browse the repository at this point in the history
Fix tsig support issue
  • Loading branch information
thibault-cne authored Jul 31, 2024
2 parents bf142de + 5b09ee6 commit 24588be
Show file tree
Hide file tree
Showing 9 changed files with 417 additions and 168 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ domain = { version = "0.10.1", features = [
"net",
"unstable-server-transport",
"unstable-zonetree",
"tsig",
] }
log = { version = "0.4.22", features = ["std"] }
notify = { version = "6.1.1" }
Expand Down
80 changes: 80 additions & 0 deletions src/dns/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::sync::{Arc, RwLock};

use domain::base::iana::{Class, Rcode};
use domain::base::ToName;

use domain::zonetree::{Answer, ReadableZone, Zone};

use crate::config::Config;
use crate::error::Error;
use crate::key::KeyStore;
use crate::zone::ZoneTree;

pub use service::dns;

mod service;

type Zones = Arc<RwLock<ZoneTree>>;

pub struct State {
pub config: Config,
pub zones: Zones,
pub keystore: Arc<RwLock<KeyStore>>,
}

impl State {
pub fn config(&self) -> &Config {
&self.config
}

fn find_zone<N, F>(&self, qname: &N, class: Class, f: F) -> Answer
where
N: ToName,
F: FnOnce(Option<Box<dyn ReadableZone>>) -> Answer,
{
if class != Class::IN {
return Answer::new(Rcode::NXDOMAIN);
}

let zones = self.zones.read().unwrap();
f(zones.find_zone(qname).map(|z| z.read()))
}

pub fn insert_zone(&self, zone: Zone) -> Result<(), Error> {
log::info!(target: "zone_change", "adding zone {}", zone.apex_name());
let mut zones = self.zones.write().unwrap();
zones.insert_zone(zone)
}

pub fn remove_zone<N>(&self, name: &N, class: Class) -> Result<(), Error>
where
N: ToName,
{
log::info!(target: "zone_change", "removing zone {} {}", name.to_bytes(), class);

let mut zones = self.zones.write().unwrap();

for z in zones.iter_zones() {
log::debug!(target: "zone", "zone {:?}", z);
}

zones.remove_zone(name)?;

for z in zones.iter_zones() {
log::debug!(target: "zone", "zone {}", z.apex_name());
}

Ok(())
}
}

impl From<Config> for State {
fn from(config: Config) -> Self {
let zones = Arc::new(RwLock::new(ZoneTree::new()));
State {
config,
zones,
keystore: KeyStore::new_shared(),
}
}
}
174 changes: 91 additions & 83 deletions src/dns.rs → src/dns/service.rs
Original file line number Diff line number Diff line change
@@ -1,78 +1,20 @@
use std::future::{ready, Future};
use std::sync::{Arc, Mutex, RwLock};
use std::sync::{Arc, Mutex};

use domain::base::iana::{Class, Opcode, Rcode};
use domain::base::message_builder::AdditionalBuilder;
use domain::base::{Message, Name, Rtype, ToName};
use domain::net::server::message::Request;
use domain::net::server::service::{CallResult, ServiceError, Transaction, TransactionStream};
use domain::net::server::util::mk_builder_for_target;
use domain::zonetree::{Answer, ReadableZone, Rrset, Zone};
use domain::rdata::tsig::Time48;
use domain::tsig::{Key, ServerSequence, ServerTransaction};
use domain::zonetree::{Answer, Rrset};
use octseq::OctetsBuilder;

use crate::config::Config;
use crate::error::Error;
use crate::zone::ZoneTree;
use crate::key::KeyStore;

type Zones = Arc<RwLock<ZoneTree>>;

pub struct State {
config: Config,
zones: Zones,
}

impl State {
pub fn config(&self) -> &Config {
&self.config
}

fn find_zone<N, F>(&self, qname: &N, class: Class, f: F) -> Answer
where
N: ToName,
F: FnOnce(Option<Box<dyn ReadableZone>>) -> Answer,
{
if class != Class::IN {
return Answer::new(Rcode::NXDOMAIN);
}

let zones = self.zones.read().unwrap();
f(zones.find_zone(qname).map(|z| z.read()))
}

pub fn insert_zone(&self, zone: Zone) -> Result<(), Error> {
log::info!(target: "zone_change", "adding zone {}", zone.apex_name());
let mut zones = self.zones.write().unwrap();
zones.insert_zone(zone)
}

pub fn remove_zone<N>(&self, name: &N, class: Class) -> Result<(), Error>
where
N: ToName,
{
log::info!(target: "zone_change", "removing zone {} {}", name.to_bytes(), class);

let mut zones = self.zones.write().unwrap();

for z in zones.iter_zones() {
log::debug!(target: "zone", "zone {:?}", z);
}

zones.remove_zone(name)?;

for z in zones.iter_zones() {
log::debug!(target: "zone", "zone {}", z.apex_name());
}

Ok(())
}
}

impl From<Config> for State {
fn from(config: Config) -> Self {
let zones = Arc::new(RwLock::new(ZoneTree::new()));
State { config, zones }
}
}
use super::State;

pub fn dns(
request: Request<Vec<u8>>,
Expand All @@ -98,18 +40,31 @@ async fn handle_non_axfr_request(
request: Request<Vec<u8>>,
state: Arc<State>,
) -> Result<CallResult<Vec<u8>>, ServiceError> {
let question = request.message().sole_question().unwrap();
let answer = state.find_zone(question.qname(), question.qclass(), |zone| match zone {
Some(zone) => {
let qname = question.qname().to_bytes();
let qtype = question.qtype();
zone.query(qname, qtype).unwrap()
}
None => Answer::new(Rcode::NXDOMAIN),
});
let answer = {
let question = request.message().sole_question().unwrap();
state.find_zone(question.qname(), question.qclass(), |zone| match zone {
Some(zone) => {
let qname = question.qname().to_bytes();
let qtype = question.qtype();
zone.query(qname, qtype).unwrap()
}
None => Answer::new(Rcode::NXDOMAIN),
})
};

let builder = mk_builder_for_target();
let additional = answer.to_message(request.message(), builder);
let mut additional = answer.to_message(request.message(), builder);

let keystore = state.keystore.read().unwrap();
let mut message = request.message().clone();
let message = Arc::make_mut(&mut message);

match ServerTransaction::<Arc<Key>>::request::<KeyStore, _>(&keystore, message, Time48::now()) {
Ok(None) => (),
Ok(Some(transaction)) => transaction.answer(&mut additional, Time48::now()).unwrap(),
_ => (),
}

Ok(CallResult::new(additional))
}

Expand All @@ -118,13 +73,35 @@ async fn handle_axfr_request(
state: Arc<State>,
) -> TransactionStream<Result<CallResult<Vec<u8>>, ServiceError>> {
let mut stream = TransactionStream::default();
let keystore = state.keystore.read().unwrap();
let mut message = request.message().clone();
let message = Arc::make_mut(&mut message);

let mut server_sequence =
match ServerSequence::<Arc<Key>>::request::<KeyStore, _>(&keystore, message, Time48::now())
{
Ok(sequence) => sequence,
_ => return stream,
};

let request = Request::new(
request.client_addr(),
request.received_at(),
message.to_owned(),
request.transport_ctx().to_owned(),
);

// Look up the zone for the queried name.
let question = request.message().sole_question().unwrap();

if question.qclass() == Class::IN {
let answer = Answer::new(Rcode::NXDOMAIN);
add_to_stream(answer, request.message(), &mut stream);
add_to_stream(
server_sequence.as_mut(),
answer,
request.message(),
&mut stream,
);
return stream;
}

Expand All @@ -134,7 +111,12 @@ async fn handle_axfr_request(
// If not found, return an NXDOMAIN error response.
let Some(zone) = zone else {
let answer = Answer::new(Rcode::NXDOMAIN);
add_to_stream(answer, request.message(), &mut stream);
add_to_stream(
server_sequence.as_mut(),
answer,
request.message(),
&mut stream,
);
return stream;
};

Expand All @@ -158,12 +140,22 @@ async fn handle_axfr_request(
let qname = question.qname().to_bytes();
let Ok(soa_answer) = zone.query(qname, Rtype::SOA) else {
let answer = Answer::new(Rcode::SERVFAIL);
add_to_stream(answer, request.message(), &mut stream);
add_to_stream(
server_sequence.as_mut(),
answer,
request.message(),
&mut stream,
);
return stream;
};

// Push the begin SOA response message into the stream
add_to_stream(soa_answer.clone(), request.message(), &mut stream);
add_to_stream(
server_sequence.as_mut(),
soa_answer.clone(),
request.message(),
&mut stream,
);

// "The AXFR protocol treats the zone contents as an unordered
// collection (or to use the mathematical term, a "set") of
Expand Down Expand Up @@ -195,6 +187,8 @@ async fn handle_axfr_request(
let stream = Arc::new(Mutex::new(stream));
let cloned_stream = stream.clone();
let cloned_msg = request.message().clone();
let server_sequence = Arc::new(Mutex::new(server_sequence));
let cloned_server_sequence = server_sequence.clone();

let op = Box::new(move |owner: Name<_>, rrset: &Rrset| {
if rrset.rtype() != Rtype::SOA {
Expand All @@ -206,38 +200,52 @@ async fn handle_axfr_request(

let additional = answer.additional();
let mut stream = cloned_stream.lock().unwrap();
add_additional_to_stream(additional, &cloned_msg, &mut stream);
let mut server_sequence = cloned_server_sequence.lock().unwrap();
add_additional_to_stream(
server_sequence.as_mut(),
additional,
&cloned_msg,
&mut stream,
);
}
});
zone.walk(op);

let mutex = Arc::try_unwrap(stream).unwrap();
let mut stream = mutex.into_inner().unwrap();
let mutex = Arc::try_unwrap(server_sequence).unwrap();
let mut server_sequence = mutex.into_inner().unwrap();

// Push the end SOA response message into the stream
add_to_stream(soa_answer, request.message(), &mut stream);
add_to_stream(
server_sequence.as_mut(),
soa_answer,
request.message(),
&mut stream,
);

stream
}

#[allow(clippy::type_complexity)]
fn add_to_stream(
sequence: Option<&mut ServerSequence<Arc<Key>>>,
answer: Answer,
msg: &Message<Vec<u8>>,
stream: &mut TransactionStream<Result<CallResult<Vec<u8>>, ServiceError>>,
) {
let builder = mk_builder_for_target();
let additional = answer.to_message(msg, builder);
add_additional_to_stream(additional, msg, stream);
add_additional_to_stream(sequence, additional, msg, stream);
}

#[allow(clippy::type_complexity)]
fn add_additional_to_stream(
sequence: Option<&mut ServerSequence<Arc<Key>>>,
mut additional: AdditionalBuilder<domain::base::StreamTarget<Vec<u8>>>,
msg: &Message<Vec<u8>>,
stream: &mut TransactionStream<Result<CallResult<Vec<u8>>, ServiceError>>,
) {
set_axfr_header(msg, &mut additional);
sequence.map(|sequence| sequence.answer(&mut additional, Time48::now()));
stream.push(ready(Ok(CallResult::new(additional))));
}

Expand Down
Loading

0 comments on commit 24588be

Please sign in to comment.