diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index fdee5039ade6..437a9b0ba65f 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -2,21 +2,24 @@ package net.mullvad.talpid import android.net.ConnectivityManager import android.net.LinkProperties -import android.net.Network import android.net.NetworkCapabilities import android.net.NetworkRequest import co.touchlab.kermit.Logger import java.net.InetAddress import kotlin.collections.ArrayList import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.merge import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.runBlocking import net.mullvad.talpid.model.NetworkState import net.mullvad.talpid.util.NetworkEvent import net.mullvad.talpid.util.RawNetworkState @@ -30,6 +33,7 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager) get() = _isConnected.value private lateinit var _currentNetworkState: StateFlow + private val resetNetworkState: Channel = Channel() // Used by JNI val currentDefaultNetworkState: NetworkState? @@ -44,51 +48,76 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager) // the default network may fail if the network on Android 11 // https://issuetracker.google.com/issues/175055271?pli=1 _currentNetworkState = - connectivityManager - .defaultRawNetworkStateFlow() + merge( + connectivityManager.defaultRawNetworkStateFlow(), + resetNetworkState.receiveAsFlow().map { null }, + ) .map { it?.toNetworkState() } - .onEach { notifyDefaultNetworkChange(it) } + .onEach { + Logger.d("NetworkState routes: ${it?.routes}") + notifyDefaultNetworkChange(it) + } .stateIn(scope, SharingStarted.Eagerly, null) _isConnected = hasInternetCapability() .onEach { notifyConnectivityChange(it) } - .stateIn(scope, SharingStarted.Eagerly, false) + .stateIn( + scope, + SharingStarted.Eagerly, + true, // Assume we have internet until we know otherwise + ) + } + + /** + * Invalidates the network state cache. E.g when the VPN is connected or disconnected, and we + * know the last known values not to be correct anymore. + */ + fun invalidateNetworkStateCache() { + // TODO remove runBlocking + runBlocking { resetNetworkState.send(Unit) } } private fun LinkProperties.dnsServersWithoutFallback(): List = dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER } - private fun hasInternetCapability(): Flow { - val request = - NetworkRequest.Builder() - .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) - .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) - .build() + private val nonVPNNetworksRequest = + NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build() + private fun hasInternetCapability(): Flow { + @Suppress("DEPRECATION") return connectivityManager - .networkEvents(request) - .scan(setOf()) { networks, event -> + .networkEvents(nonVPNNetworksRequest) + .scan( + connectivityManager.allNetworks.associateWith { + connectivityManager.getNetworkCapabilities(it) + } + ) { networks, event -> when (event) { - is NetworkEvent.Available -> { - Logger.d("Network available ${event.network}") - (networks + event.network).also { - Logger.d("Number of networks: ${it.size}") - } - } is NetworkEvent.Lost -> { Logger.d("Network lost ${event.network}") + (networks - event.network).also { Logger.d("Number of networks: ${it.size}") } } + is NetworkEvent.CapabilitiesChanged -> { + Logger.d("Network capabilities changed ${event.network}") + (networks + (event.network to event.networkCapabilities)).also { + Logger.d("Number of networks: ${it.size}") + } + } else -> networks } } - .map { it.isNotEmpty() } .distinctUntilChanged() + .map { it.any { it.value.hasInternetCapability() } } + .onEach { Logger.d("Do we have connectivity? $it") } } + private fun NetworkCapabilities?.hasInternetCapability(): Boolean = + this?.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) == true + private fun RawNetworkState.toNetworkState(): NetworkState = NetworkState( network.networkHandle, diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index a143df61322e..1457ff35f441 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -57,34 +57,22 @@ open class TalpidVpnService : LifecycleVpnService() { // Used by JNI fun openTun(config: TunConfig): CreateTunResult = synchronized(this) { - val tunStatus = activeTunStatus - - if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) { - tunStatus - } else { - openTunImpl(config) + createTun(config).merge().also { + currentTunConfig = config + activeTunStatus = it } } // Used by JNI - fun openTunForced(config: TunConfig): CreateTunResult = - synchronized(this) { openTunImpl(config) } - - // Used by JNI - fun closeTun(): Unit = synchronized(this) { activeTunStatus = null } + fun closeTun(): Unit = + synchronized(this) { + connectivityListener.invalidateNetworkStateCache() + activeTunStatus = null + } // Used by JNI fun bypass(socket: Int): Boolean = protect(socket) - private fun openTunImpl(config: TunConfig): CreateTunResult { - val newTunStatus = createTun(config).merge() - - currentTunConfig = config - activeTunStatus = newTunStatus - - return newTunStatus - } - private fun createTun( config: TunConfig ): Either = either { @@ -123,6 +111,7 @@ open class TalpidVpnService : LifecycleVpnService() { builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER) } + connectivityListener.invalidateNetworkStateCache() val vpnInterfaceFd = builder .establishSafe() diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt index fddaa6fb8806..2f150cf67856 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt @@ -109,24 +109,19 @@ fun ConnectivityManager.networkEvents(networkRequest: NetworkRequest): Flow = - defaultNetworkEvents() - .scan( - null as RawNetworkState?, - { state, event -> - return@scan when (event) { - is NetworkEvent.Available -> RawNetworkState(network = event.network) - is NetworkEvent.BlockedStatusChanged -> - state?.copy(blockedStatus = event.blocked) - is NetworkEvent.CapabilitiesChanged -> - state?.copy(networkCapabilities = event.networkCapabilities) - is NetworkEvent.LinkPropertiesChanged -> - state?.copy(linkProperties = event.linkProperties) - is NetworkEvent.Losing -> state?.copy(maxMsToLive = event.maxMsToLive) - is NetworkEvent.Lost -> null - NetworkEvent.Unavailable -> null - } - }, - ) + defaultNetworkEvents().scan(null as RawNetworkState?) { state, event -> state.reduce(event) } + +internal fun RawNetworkState?.reduce(event: NetworkEvent): RawNetworkState? = + when (event) { + is NetworkEvent.Available -> RawNetworkState(network = event.network) + is NetworkEvent.BlockedStatusChanged -> this?.copy(blockedStatus = event.blocked) + is NetworkEvent.CapabilitiesChanged -> + this?.copy(networkCapabilities = event.networkCapabilities) + is NetworkEvent.LinkPropertiesChanged -> this?.copy(linkProperties = event.linkProperties) + is NetworkEvent.Losing -> this?.copy(maxMsToLive = event.maxMsToLive) + is NetworkEvent.Lost -> null + NetworkEvent.Unavailable -> null + } sealed interface NetworkEvent { data class Available(val network: Network) : NetworkEvent diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 0ad41b049c14..05024554a377 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -2,8 +2,6 @@ use futures::channel::{mpsc, oneshot}; use futures::stream::Fuse; use futures::StreamExt; -#[cfg(target_os = "android")] -use talpid_tunnel::tun_provider::Error; use talpid_types::net::{AllowedClients, AllowedEndpoint, TunnelParameters}; use talpid_types::tunnel::{ErrorStateCause, FirewallPolicyError}; use talpid_types::{BoxedError, ErrorExt}; @@ -260,14 +258,7 @@ impl ConnectedState { let consequence = if shared_values.set_allow_lan(allow_lan) { #[cfg(target_os = "android")] { - if let Err(_err) = shared_values.restart_tunnel(false) { - self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::StartTunnelError), - ) - } else { - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } #[cfg(not(target_os = "android"))] { @@ -298,22 +289,7 @@ impl ConnectedState { let consequence = if shared_values.set_dns_config(servers) { #[cfg(target_os = "android")] { - if let Err(_err) = shared_values.restart_tunnel(false) { - match _err { - Error::InvalidDnsServers(ip_addrs) => self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::InvalidDnsServers( - ip_addrs, - )), - ), - _ => self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::StartTunnelError), - ), - } - } else { - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } #[cfg(not(target_os = "android"))] { @@ -385,17 +361,8 @@ impl ConnectedState { #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { if shared_values.set_excluded_paths(paths) { - if let Err(err) = shared_values.restart_tunnel(false) { - let _ = - result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err))); - self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::SplitTunnelError), - ) - } else { - let _ = result_tx.send(Ok(())); - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } + let _ = result_tx.send(Ok(())); + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } else { let _ = result_tx.send(Ok(())); SameState(self) diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 9060787536db..3b1f7dc76987 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -29,8 +29,6 @@ use crate::tunnel::{self, TunnelMonitor}; pub(crate) type TunnelCloseEvent = Fuse>>; -#[cfg(target_os = "android")] -const MAX_ATTEMPTS_WITH_SAME_TUN: u32 = 5; const MIN_TUNNEL_ALIVE_TIME: Duration = Duration::from_millis(1000); #[cfg(target_os = "windows")] const MAX_ATTEMPT_CREATE_TUN: u32 = 4; @@ -38,6 +36,7 @@ const MAX_ATTEMPT_CREATE_TUN: u32 = 4; const INITIAL_ALLOWED_TUNNEL_TRAFFIC: AllowedTunnelTraffic = AllowedTunnelTraffic::None; /// The tunnel has been started, but it is not established/functional. +#[derive(Debug)] pub struct ConnectingState { tunnel_events: TunnelEventsReceiver, tunnel_parameters: TunnelParameters, @@ -114,20 +113,11 @@ impl ConnectingState { ErrorStateCause::SetFirewallPolicyError(error), ) } else { + // This is magically shimmed in on the side on Android to prep the TunConfig + // with the right DNS servers. On Android DNS is part of creating the VPN + // interface and this call should be part of start_tunnel call instead #[cfg(target_os = "android")] - { - shared_values.prepare_tun_config(false); - if retry_attempt > 0 && retry_attempt % MAX_ATTEMPTS_WITH_SAME_TUN == 0 { - if let Err(error) = - { shared_values.tun_provider.lock().unwrap().open_tun_forced() } - { - log::error!( - "{}", - error.display_chain_with_msg("Failed to recreate tun device") - ); - } - } - } + shared_values.prepare_tun_config(false); let connecting_state = Self::start_tunnel( shared_values.runtime.clone(), @@ -386,14 +376,7 @@ impl ConnectingState { let consequence = if shared_values.set_allow_lan(allow_lan) { #[cfg(target_os = "android")] { - if let Err(_err) = shared_values.restart_tunnel(false) { - self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::StartTunnelError), - ) - } else { - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } #[cfg(not(target_os = "android"))] self.reset_firewall(shared_values) @@ -427,14 +410,7 @@ impl ConnectingState { let consequence = if shared_values.set_dns_config(servers) { #[cfg(target_os = "android")] { - if let Err(_err) = shared_values.restart_tunnel(false) { - self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::StartTunnelError), - ) - } else { - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } #[cfg(not(target_os = "android"))] SameState(self) @@ -484,17 +460,8 @@ impl ConnectingState { #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { if shared_values.set_excluded_paths(paths) { - if let Err(err) = shared_values.restart_tunnel(false) { - let _ = - result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err))); - self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::SplitTunnelError), - ) - } else { - let _ = result_tx.send(Ok(())); - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } + let _ = result_tx.send(Ok(())); + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } else { let _ = result_tx.send(Ok(())); SameState(self) diff --git a/talpid-routing/src/unix/android.rs b/talpid-routing/src/unix/android.rs index 137e69c1deb5..9ade7a19f06c 100644 --- a/talpid-routing/src/unix/android.rs +++ b/talpid-routing/src/unix/android.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::ops::{ControlFlow, Not}; +use std::ops::ControlFlow; use std::sync::Mutex; use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; @@ -51,7 +51,7 @@ pub struct RouteManagerImpl { last_state: Option, /// Clients waiting on response to [RouteManagerCommand::WaitForRoutes]. - waiting_for_routes: Vec>, + waiting_for_routes: Vec<(oneshot::Sender<()>, Vec)>, } impl RouteManagerImpl { @@ -64,7 +64,7 @@ impl RouteManagerImpl { // Try to poll for the current network state at startup. // This will most likely be null, but it covers the edge case where a NetworkState - // update has been emitted before we anyone starts to listen for route updates some + // update has been emitted before anyone starts to listen for route updates some // time in the future (when connecting). let last_state = match current_network_state(android_context) { Ok(initial_state) => initial_state, @@ -105,12 +105,17 @@ impl RouteManagerImpl { // update the last known NetworkState self.last_state = network_state; - if has_routes(self.last_state.as_ref()) { - // notify waiting clients that routes exist - for client in self.waiting_for_routes.drain(..) { + // notify waiting clients that routes exist + let mut unused_routes: Vec<(oneshot::Sender<()>, Vec)> = Vec::new(); + let ret = for (client, expected_routes) in self.waiting_for_routes.drain(..) { + if has_routes(self.last_state.as_ref(), expected_routes.clone()) { let _ = client.send(()); + } else { + unused_routes.push((client, expected_routes)); } - } + }; + self.waiting_for_routes = unused_routes; + ret } } } @@ -126,31 +131,44 @@ impl RouteManagerImpl { let _ = tx.send(()); return ControlFlow::Break(()); } - RouteManagerCommand::WaitForRoutes(response_tx) => { + RouteManagerCommand::WaitForRoutes(response_tx, expected_routes) => { // check if routes have already been configured on the Android system. // otherwise, register a listener for network state changes. // routes may come in at any moment in the future. - if has_routes(self.last_state.as_ref()) { + if has_routes(self.last_state.as_ref(), expected_routes.clone()) { let _ = response_tx.send(()); } else { - self.waiting_for_routes.push(response_tx); + self.waiting_for_routes.push((response_tx, expected_routes)); } } + RouteManagerCommand::ClearRoutes(tx) => { + self.clear_routes(); + let _ = tx.send(()); + } } ControlFlow::Continue(()) } + + pub fn clear_routes(&mut self) { + self.last_state = None; + } } -/// Check whether the [NetworkState] contains any routes. +/// Check whether the [NetworkState] contains expected routes. /// -/// Since we are the ones telling Android what routes to set, we make the assumption that: -/// If any routes exist whatsoever, they are the the routes we specified. -fn has_routes(state: Option<&NetworkState>) -> bool { +/// Matches the routes reported from Android and checks if all the routes we expect to be there is +/// present. +fn has_routes(state: Option<&NetworkState>, expected_routes: Vec) -> bool { let Some(network_state) = state else { return false; }; - configured_routes(network_state).is_empty().not() + + let routes = configured_routes(network_state); + if routes.is_empty() { + return false; + } + routes.is_superset(&HashSet::from_iter(expected_routes)) } fn configured_routes(state: &NetworkState) -> HashSet { diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 5aedc9626ea9..63a68745e9e6 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -37,6 +37,8 @@ mod imp; #[path = "android.rs"] mod imp; +#[cfg(target_os = "android")] +use crate::Route; #[cfg(any(target_os = "macos", target_os = "linux"))] pub use imp::Error as PlatformError; @@ -103,7 +105,8 @@ pub(crate) enum RouteManagerCommand { #[cfg(target_os = "android")] #[derive(Debug)] pub(crate) enum RouteManagerCommand { - WaitForRoutes(oneshot::Sender<()>), + ClearRoutes(oneshot::Sender<()>), + WaitForRoutes(oneshot::Sender<()>, Vec), Shutdown(oneshot::Sender<()>), } @@ -215,7 +218,7 @@ impl RouteManagerHandle { /// This function is guaranteed to *not* wait for longer than 2 seconds. /// Please, see the implementation of this function for further details. #[cfg(target_os = "android")] - pub async fn wait_for_routes(&self) -> Result<(), Error> { + pub async fn wait_for_routes(&self, expect_routes: Vec) -> Result<(), Error> { use std::time::Duration; use tokio::time::timeout; /// Maximum time to wait for routes to come up. The expected mean time is low (~200 ms), but @@ -224,7 +227,7 @@ impl RouteManagerHandle { let (result_tx, result_rx) = oneshot::channel(); self.tx - .unbounded_send(RouteManagerCommand::WaitForRoutes(result_tx)) + .unbounded_send(RouteManagerCommand::WaitForRoutes(result_tx, expect_routes)) .map_err(|_| Error::RouteManagerDown)?; timeout(WAIT_FOR_ROUTES_TIMEOUT, result_rx) @@ -247,6 +250,18 @@ impl RouteManagerHandle { Ok(()) } + /// xD + /// + #[cfg(target_os = "android")] + pub async fn clear_android_routes(&self) -> Result<(), Error> { + let (result_tx, result_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::ClearRoutes(result_tx)) + .map_err(|_| Error::RouteManagerDown)?; + let _ = result_rx.await; + Ok(()) + } + /// Listen for non-tunnel default route changes. #[cfg(target_os = "macos")] pub async fn default_route_listener( diff --git a/talpid-tunnel/src/tun_provider/android/mod.rs b/talpid-tunnel/src/tun_provider/android/mod.rs index f285b4a64ca1..96e0b8fe57f1 100644 --- a/talpid-tunnel/src/tun_provider/android/mod.rs +++ b/talpid-tunnel/src/tun_provider/android/mod.rs @@ -16,6 +16,7 @@ use std::{ os::unix::io::{AsRawFd, RawFd}, sync::Arc, }; +use talpid_routing::Route; use talpid_types::net::{ALLOWED_LAN_MULTICAST_NETS, ALLOWED_LAN_NETS}; use talpid_types::{android::AndroidContext, ErrorExt}; @@ -65,6 +66,7 @@ pub struct AndroidTunProvider { class: GlobalRef, object: GlobalRef, config: TunConfig, + current_config: Option<(VpnServiceConfig, RawFd)>, } impl AndroidTunProvider { @@ -83,6 +85,7 @@ impl AndroidTunProvider { class: talpid_vpn_service_class, object: context.vpn_service, config, + current_config: None, } } @@ -93,51 +96,65 @@ impl AndroidTunProvider { } /// Open a tunnel with the current configuration. - pub fn open_tun(&mut self) -> Result { + pub fn open_tun(&mut self) -> Result<(VpnServiceTun, bool), Error> { self.open_tun_inner("openTun") } /// Open a tunnel with the current configuration. - /// Force recreation even if the tunnel config hasn't changed. - pub fn open_tun_forced(&mut self) -> Result { - self.open_tun_inner("openTunForced") - } - - /// Open a tunnel with the current configuration. - fn open_tun_inner(&mut self, get_tun_func_name: &'static str) -> Result { - let tun_fd = self.open_tun_fd(get_tun_func_name)?; + fn open_tun_inner( + &mut self, + get_tun_func_name: &'static str, + ) -> Result<(VpnServiceTun, bool), Error> { + let (tun_fd, reuse) = self.open_tun_fd(get_tun_func_name)?; + log::debug!("DEBUG: Opening tun: {}", tun_fd); let jvm = unsafe { JavaVM::from_raw(self.jvm.get_java_vm_pointer()) } .map_err(Error::CloneJavaVm)?; - Ok(VpnServiceTun { - tunnel: tun_fd, - jvm, - class: self.class.clone(), - object: self.object.clone(), - }) + Ok(( + VpnServiceTun { + tunnel: tun_fd, + jvm, + class: self.class.clone(), + object: self.object.clone(), + }, + reuse, + )) } - fn open_tun_fd(&self, get_tun_func_name: &'static str) -> Result { + fn open_tun_fd(&mut self, get_tun_func_name: &'static str) -> Result<(RawFd, bool), Error> { let config = VpnServiceConfig::new(self.config.clone()); - let env = self.env()?; - let java_config = config.into_java(&env); - - let result = self.call_method( - get_tun_func_name, - "(Lnet/mullvad/talpid/model/TunConfig;)Lnet/mullvad/talpid/model/CreateTunResult;", - JavaType::Object("net/mullvad/talpid/model/CreateTunResult".to_owned()), - &[JValue::Object(java_config.as_obj())], - )?; - - match result { - JValue::Object(result) => CreateTunResult::from_java(&env, result).into(), - value => Err(Error::InvalidMethodResult( + // If we are recreating the same tunnel we return the same file descriptor to avoid calling + // open_tun in android since it may cause leaks. + if let Some(current_config) = &self.current_config { + if current_config.0 == config { + return Ok((current_config.1, false)); + } + } + let create_result = { + let env = self.env()?; + let java_config = config.clone().into_java(&env); + let result = self.call_method( get_tun_func_name, - format!("{:?}", value), - )), + "(Lnet/mullvad/talpid/model/TunConfig;)Lnet/mullvad/talpid/model/CreateTunResult;", + JavaType::Object("net/mullvad/talpid/model/CreateTunResult".to_owned()), + &[JValue::Object(java_config.as_obj())], + )?; + + match result { + JValue::Object(result) => CreateTunResult::from_java(&env, result).into(), + value => Err(Error::InvalidMethodResult( + get_tun_func_name, + format!("{:?}", value), + )), + } + .map(|raw_fd| (raw_fd, true)) + }; + if let Ok(create_result) = create_result { + self.current_config = Some((config, create_result.0)); } + create_result } /// Close currently active tunnel device. @@ -158,6 +175,9 @@ impl AndroidTunProvider { "{}", error.display_chain_with_msg("Failed to close the tunnel") ); + } else { + // Remove the cache of config + self.current_config = None; } } @@ -188,6 +208,14 @@ impl AndroidTunProvider { } } + pub fn real_routes(&self) -> Vec { + self.config + .real_routes() + .iter() + .map(|ip_network| Route::new(ip_network.clone())) + .collect() + } + fn call_method( &self, name: &'static str, @@ -221,7 +249,7 @@ impl AndroidTunProvider { /// Configuration to use for VpnService #[derive(Clone, Debug, Eq, PartialEq, IntoJava)] #[jnix(class_name = "net.mullvad.talpid.model.TunConfig")] -struct VpnServiceConfig { +pub struct VpnServiceConfig { /// IP addresses for the tunnel interface. pub addresses: Vec, @@ -318,7 +346,7 @@ impl VpnServiceConfig { #[derive(Clone, Debug, Eq, PartialEq, IntoJava)] #[jnix(package = "net.mullvad.talpid.model")] -struct InetNetwork { +pub struct InetNetwork { address: IpAddr, prefix: i16, } @@ -332,9 +360,19 @@ impl From for InetNetwork { } } +impl From<&InetNetwork> for IpNetwork { + fn from(inet_network: &InetNetwork) -> Self { + IpNetwork::new( + inet_network.address, + inet_network.prefix.to_be_bytes().last().unwrap().clone(), + ) + .unwrap() + } +} + /// Handle to a tunnel device on Android. pub struct VpnServiceTun { - tunnel: RawFd, + pub tunnel: RawFd, jvm: JavaVM, class: GlobalRef, object: GlobalRef, diff --git a/talpid-tunnel/src/tun_provider/mod.rs b/talpid-tunnel/src/tun_provider/mod.rs index 1bf4e1abb483..9ce767aa8856 100644 --- a/talpid-tunnel/src/tun_provider/mod.rs +++ b/talpid-tunnel/src/tun_provider/mod.rs @@ -1,3 +1,5 @@ +#[cfg(target_os = "android")] +use crate::tun_provider::imp::VpnServiceConfig; use cfg_if::cfg_if; use ipnetwork::IpNetwork; use std::{ @@ -73,6 +75,17 @@ impl TunConfig { } servers } + + /// Routes to configure for the tunnel. + #[cfg(target_os = "android")] + pub fn real_routes(&self) -> Vec { + VpnServiceConfig::new(self.clone()) + .routes + .clone() + .iter() + .map(|x| IpNetwork::from(x)) + .collect() + } } /// Return a tunnel configuration that routes all traffic inside the tunnel. diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index fe1a848e9a74..05b55c173796 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -28,6 +28,8 @@ use talpid_tunnel::{ tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, }; +#[cfg(target_os = "android")] +use talpid_routing::RouteManagerHandle; #[cfg(daita)] use talpid_tunnel_config_client::DaitaSettings; use talpid_types::{ @@ -434,6 +436,7 @@ impl WireguardMonitor { &config, log_path, args.tun_provider.clone(), + args.route_manager, // In case we should negotiate an ephemeral peer, we should specify via AllowedIPs // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. @@ -465,13 +468,6 @@ impl WireguardMonitor { .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; - // Wait for routes to come up - args.route_manager - .wait_for_routes() - .await - .map_err(Error::SetupRoutingError) - .map_err(CloseMsg::SetupError)?; - if should_negotiate_ephemeral_peer { let ephemeral_obfs_sender = close_obfs_sender.clone(); @@ -743,12 +739,13 @@ impl WireguardMonitor { config: &Config, log_path: Option<&Path>, #[cfg(unix)] tun_provider: Arc>, + #[cfg(target_os = "android")] route_manager: RouteManagerHandle, #[cfg(windows)] setup_done_tx: mpsc::Sender>, #[cfg(windows)] route_manager: talpid_routing::RouteManagerHandle, #[cfg(target_os = "android")] gateway_only: bool, #[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver, ) -> Result { - #[cfg(unix)] + #[cfg(all(unix, not(target_os = "android")))] let routes = config .get_tunnel_destinations() .flat_map(Self::replace_default_prefixes); @@ -770,6 +767,7 @@ impl WireguardMonitor { // tunnel to where the ephemeral peer resides. // // Refer to `docs/architecture.md` for details on how to use multihop + PQ. + log::debug!("patching allowed ips"); #[cfg(target_os = "android")] let config = Self::patch_allowed_ips(config, gateway_only); @@ -780,7 +778,7 @@ impl WireguardMonitor { exit_peer, log_path, tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -791,7 +789,7 @@ impl WireguardMonitor { &config, log_path, tun_provider, - routes, + route_manager, cancel_receiver, ) .await @@ -819,7 +817,6 @@ impl WireguardMonitor { .block_on(self.event_hook.on_event(TunnelEvent::Down)); self.stop_tunnel(); - wait_result } @@ -969,6 +966,7 @@ impl WireguardMonitor { } /// Replace default (0-prefix) routes with more specific routes. + #[cfg(all(unix, not(target_os = "android")))] fn replace_default_prefixes(network: ipnetwork::IpNetwork) -> Vec { #[cfg(windows)] if network.prefix() == 0 { diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 813490899a6e..d1b0495a60ba 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -1,15 +1,16 @@ #[cfg(target_os = "android")] use super::config; +#[cfg(target_os = "android")] use super::{ stats::{Stats, StatsMap}, - Config, Tunnel, TunnelError, + CloseMsg, Config, Error, Tunnel, TunnelError, }; #[cfg(target_os = "linux")] use crate::config::MULLVAD_INTERFACE_NAME; #[cfg(target_os = "android")] use crate::connectivity; use crate::logging::{clean_up_logging, initialize_logging}; -#[cfg(unix)] +#[cfg(all(unix, not(target_os = "android")))] use ipnetwork::IpNetwork; #[cfg(daita)] use std::ffi::CString; @@ -23,6 +24,8 @@ use std::{ pin::Pin, }; #[cfg(target_os = "android")] +use talpid_routing::RouteManagerHandle; +#[cfg(target_os = "android")] use talpid_tunnel::tun_provider::Error as TunProviderError; #[cfg(not(target_os = "windows"))] use talpid_tunnel::tun_provider::{Tun, TunProvider}; @@ -115,7 +118,7 @@ impl WgGoTunnel { let log_path = state._logging_context.path.clone(); let cancel_receiver = state.cancel_receiver.clone(); let tun_provider = Arc::clone(&state.tun_provider); - let routes = config.get_tunnel_destinations(); + let route_manager = &state.route_manager.clone(); match self { WgGoTunnel::Multihop(state) if !config.is_multihop() => { @@ -124,7 +127,7 @@ impl WgGoTunnel { config, log_path.as_deref(), tun_provider, - routes, + route_manager.clone(), cancel_receiver, ) .await @@ -136,22 +139,19 @@ impl WgGoTunnel { &config.exit_peer.clone().unwrap().clone(), log_path.as_deref(), tun_provider, - routes, + route_manager.clone(), cancel_receiver, ) .await } WgGoTunnel::Singlehop(mut state) => { state.set_config(config.clone())?; - // HACK: Check if the tunnel is working by sending a ping in the tunnel. let new_state = WgGoTunnel::Singlehop(state); - new_state.ensure_tunnel_is_running().await?; Ok(new_state) } WgGoTunnel::Multihop(mut state) => { state.set_config(config.clone())?; let new_state = WgGoTunnel::Multihop(state); - new_state.ensure_tunnel_is_running().await?; Ok(new_state) } } @@ -173,6 +173,8 @@ pub(crate) struct WgGoTunnelState { _logging_context: LoggingContext, #[cfg(target_os = "android")] tun_provider: Arc>, + #[cfg(target_os = "android")] + route_manager: RouteManagerHandle, #[cfg(daita)] config: Config, /// This is used to cancel the connectivity checks that occur when toggling multihop @@ -344,7 +346,40 @@ impl WgGoTunnel { } } - #[cfg(unix)] + #[cfg(target_os = "android")] + fn get_tunnel( + tun_provider: Arc>, + config: &Config, + ) -> Result<(Tun, RawFd, bool)> { + let mut tun_provider = tun_provider.lock().unwrap(); + let mut last_error = None; + let tun_config = tun_provider.config_mut(); + + tun_config.addresses = config.tunnel.addresses.clone(); + tun_config.ipv4_gateway = config.ipv4_gateway; + tun_config.ipv6_gateway = config.ipv6_gateway; + tun_config.mtu = config.mtu; + tun_config.routes = vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]; + + for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS { + let (tunnel_device, is_new_tunnel) = tun_provider + .open_tun() + .map_err(TunnelError::SetupTunnelDevice)?; + + match nix::unistd::dup(tunnel_device.as_raw_fd()) { + Ok(fd) => return Ok((tunnel_device, fd, is_new_tunnel)), + #[cfg(not(target_os = "macos"))] + Err(error @ nix::errno::Errno::EBADFD) => last_error = Some(error), + Err(error @ nix::errno::Errno::EBADF) => last_error = Some(error), + Err(error) => return Err(TunnelError::FdDuplicationError(error)), + } + } + + Err(TunnelError::FdDuplicationError( + last_error.expect("Should be collected in loop"), + )) + } + #[cfg(any(target_os = "linux", target_os = "macos"))] fn get_tunnel( tun_provider: Arc>, config: &Config, @@ -395,11 +430,13 @@ impl WgGoTunnel { config: &Config, log_path: Option<&Path>, tun_provider: Arc>, - routes: impl Iterator, + route_manager: RouteManagerHandle, cancel_receiver: connectivity::CancelReceiver, ) -> Result { - let (mut tunnel_device, tunnel_fd) = - Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + let _ = route_manager.clear_android_routes().await; + + let (mut tunnel_device, tunnel_fd, is_new_tunnel) = + Self::get_tunnel(Arc::clone(&tun_provider), config)?; let interface_name: String = tunnel_device .interface_name() @@ -427,12 +464,17 @@ impl WgGoTunnel { _tunnel_device: tunnel_device, _logging_context: logging_context, tun_provider, + route_manager, #[cfg(daita)] config: config.clone(), cancel_receiver, }); - // HACK: Check if the tunnel is working by sending a ping in the tunnel. + if is_new_tunnel { + tunnel.wait_for_routes().await?; + } + + // This seemingly fixes the GO crash we see tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) @@ -443,11 +485,13 @@ impl WgGoTunnel { exit_peer: &PeerConfig, log_path: Option<&Path>, tun_provider: Arc>, - routes: impl Iterator, + route_manager: RouteManagerHandle, cancel_receiver: connectivity::CancelReceiver, ) -> Result { - let (mut tunnel_device, tunnel_fd) = - Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; + let _ = route_manager.clear_android_routes().await; + + let (mut tunnel_device, tunnel_fd, is_new_tunnel) = + Self::get_tunnel(Arc::clone(&tun_provider), config)?; let interface_name: String = tunnel_device .interface_name() @@ -491,12 +535,17 @@ impl WgGoTunnel { _tunnel_device: tunnel_device, _logging_context: logging_context, tun_provider, + route_manager, #[cfg(daita)] config: config.clone(), cancel_receiver: cancel_receiver.clone(), }); - // HACK: Check if the tunnel is working by sending a ping in the tunnel. + if is_new_tunnel { + tunnel.wait_for_routes().await?; + } + + // This seemingly fixes the GO crash we see tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) @@ -517,6 +566,24 @@ impl WgGoTunnel { /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out. + async fn wait_for_routes(&self) -> Result<()> { + let state = self.as_state(); + + let expected_routes = state.tun_provider.lock().unwrap().real_routes(); + + // TODO HANDLE UNWRAP + // Wait for routes to come up + state + .route_manager + .clone() + .wait_for_routes(expected_routes) + .await + .map_err(Error::SetupRoutingError) + .map_err(CloseMsg::SetupError) + .unwrap(); + + Ok(()) + } async fn ensure_tunnel_is_running(&self) -> Result<()> { let state = self.as_state(); let addr = state.config.ipv4_gateway;