Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wait for routes #7659

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.drop
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.merge
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.scan
import kotlinx.coroutines.flow.stateIn
Expand All @@ -23,7 +25,10 @@ import net.mullvad.talpid.util.RawNetworkState
import net.mullvad.talpid.util.defaultRawNetworkStateFlow
import net.mullvad.talpid.util.networkEvents

class ConnectivityListener(private val connectivityManager: ConnectivityManager) {
class ConnectivityListener(
private val connectivityManager: ConnectivityManager,
private val resetDnsFlow: Flow<Unit>,
) {
private lateinit var _isConnected: StateFlow<Boolean>
// Used by JNI
val isConnected
Expand All @@ -44,49 +49,57 @@ class ConnectivityListener(private val connectivityManager: ConnectivityManager)
// the default network may fail if the network on Android 11
// https://issuetracker.google.com/issues/175055271?pli=1
_currentNetworkState =
connectivityManager
.defaultRawNetworkStateFlow()
merge(connectivityManager.defaultRawNetworkStateFlow(), resetDnsFlow.map { null })
.map { it?.toNetworkState() }
.onEach { notifyDefaultNetworkChange(it) }
.onEach {
Logger.d("NetworkState routes: ${it?.routes}")
notifyDefaultNetworkChange(it)
}
.stateIn(scope, SharingStarted.Eagerly, null)

@Suppress("DEPRECATION")
_isConnected =
hasInternetCapability()
.onEach { notifyConnectivityChange(it) }
.stateIn(scope, SharingStarted.Eagerly, false)
.stateIn(
scope,
SharingStarted.Eagerly,
true, // Assume we have internet until we know otherwise
)
}

private fun LinkProperties.dnsServersWithoutFallback(): List<InetAddress> =
dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER }

private val nonVPNNetworksRequest =
NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build()

private fun hasInternetCapability(): Flow<Boolean> {
val request =
NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
.build()

return connectivityManager
.networkEvents(request)
.scan(setOf<Network>()) { networks, event ->
.networkEvents(nonVPNNetworksRequest)
.scan(mapOf<Network, NetworkCapabilities?>()) { networks, event ->
when (event) {
is NetworkEvent.Available -> {
Logger.d("Network available ${event.network}")
(networks + event.network).also {
Logger.d("Number of networks: ${it.size}")
}
}
is NetworkEvent.Lost -> {
Logger.d("Network lost ${event.network}")

(networks - event.network).also {
Logger.d("Number of networks: ${it.size}")
}
}
is NetworkEvent.CapabilitiesChanged -> {
Logger.d("Network capabilities changed ${event.network}")
(networks + (event.network to event.networkCapabilities)).also {
Logger.d("Number of networks: ${it.size}")
}
}
else -> networks
}
}
.map { it.isNotEmpty() }
.distinctUntilChanged()
.drop(1)
.map { it.any { it.value?.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) == true} }
.onEach { Logger.d("Do we have connectivity? $it") }
}

private fun RawNetworkState.toNetworkState(): NetworkState =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
import kotlin.properties.Delegates.observable
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.receiveAsFlow
import kotlinx.coroutines.runBlocking
import net.mullvad.mullvadvpn.lib.common.util.establishSafe
import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe
import net.mullvad.mullvadvpn.lib.model.PrepareError
Expand All @@ -38,10 +41,12 @@ open class TalpidVpnService : LifecycleVpnService() {
}

if (oldTunFd != null) {
Logger.d("Closing old tunFd $oldTunFd")
ParcelFileDescriptor.adoptFd(oldTunFd).close()
}
}

private val resetDnsChannel = Channel<Unit>()
private var currentTunConfig: TunConfig? = null

// Used by JNI
Expand All @@ -50,45 +55,43 @@ open class TalpidVpnService : LifecycleVpnService() {
@CallSuper
override fun onCreate() {
super.onCreate()
connectivityListener = ConnectivityListener(getSystemService<ConnectivityManager>()!!)
connectivityListener =
ConnectivityListener(
getSystemService<ConnectivityManager>()!!,
resetDnsChannel.receiveAsFlow(),
)
connectivityListener.register(lifecycleScope)
}

// Used by JNI
fun openTun(config: TunConfig): CreateTunResult =
synchronized(this) {
val tunStatus = activeTunStatus

if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) {
tunStatus
} else {
openTunImpl(config)
Logger.d("TalpidVpnService.openTun")
createTun(config).merge().also {
currentTunConfig = config
activeTunStatus = it
}
}

// Used by JNI
fun openTunForced(config: TunConfig): CreateTunResult =
synchronized(this) { openTunImpl(config) }

// Used by JNI
fun closeTun(): Unit = synchronized(this) { activeTunStatus = null }
fun closeTun(): Unit =
synchronized(this) {
Logger.d("TalpidVpnService.closeTun")
runBlocking { resetDnsChannel.send(Unit) }
activeTunStatus = null
}

// Used by JNI
fun bypass(socket: Int): Boolean = protect(socket)

private fun openTunImpl(config: TunConfig): CreateTunResult {
val newTunStatus = createTun(config).merge()

currentTunConfig = config
activeTunStatus = newTunStatus

return newTunStatus
fun bypass(socket: Int): Boolean {
Logger.d("TalpidVpnService.bypass")
return protect(socket)
}

private fun createTun(
config: TunConfig
): Either<CreateTunResult.Error, CreateTunResult.Success> = either {
prepareVpnSafe().mapLeft { it.toCreateTunError() }.bind()
Logger.d("TalpidVpnService.createTun $config")

val builder = Builder()
builder.setMtu(config.mtu)
Expand Down Expand Up @@ -123,12 +126,14 @@ open class TalpidVpnService : LifecycleVpnService() {
builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER)
}

runBlocking { resetDnsChannel.send(Unit) }
val vpnInterfaceFd =
builder
.establishSafe()
.onLeft { Logger.w("Failed to establish tunnel $it") }
.mapLeft { EstablishError }
.bind()
Logger.d("Establish!")

val tunFd = vpnInterfaceFd.detachFd()

Expand Down
3 changes: 3 additions & 0 deletions talpid-core/src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ impl TunnelMonitor {
log_dir: &Option<path::PathBuf>,
args: TunnelArgs<'_>,
) -> Result<Self> {

log::debug!("DEBUG: TunnelMonitor::start");
Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?;
let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?;

Expand Down Expand Up @@ -182,6 +184,7 @@ impl TunnelMonitor {
log: Option<path::PathBuf>,
args: TunnelArgs<'_>,
) -> Result<Self> {
log::debug!("DEBUG: start_wireguard_tunnel");
let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
Expand Down
41 changes: 4 additions & 37 deletions talpid-core/src/tunnel_state_machine/connected_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use futures::channel::{mpsc, oneshot};
use futures::stream::Fuse;
use futures::StreamExt;

#[cfg(target_os = "android")]
use talpid_tunnel::tun_provider::Error;
use talpid_types::net::{AllowedClients, AllowedEndpoint, TunnelParameters};
use talpid_types::tunnel::{ErrorStateCause, FirewallPolicyError};
use talpid_types::{BoxedError, ErrorExt};
Expand Down Expand Up @@ -260,14 +258,7 @@ impl ConnectedState {
let consequence = if shared_values.set_allow_lan(allow_lan) {
#[cfg(target_os = "android")]
{
if let Err(_err) = shared_values.restart_tunnel(false) {
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::StartTunnelError),
)
} else {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
#[cfg(not(target_os = "android"))]
{
Expand Down Expand Up @@ -298,22 +289,7 @@ impl ConnectedState {
let consequence = if shared_values.set_dns_config(servers) {
#[cfg(target_os = "android")]
{
if let Err(_err) = shared_values.restart_tunnel(false) {
match _err {
Error::InvalidDnsServers(ip_addrs) => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::InvalidDnsServers(
ip_addrs,
)),
),
_ => self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::StartTunnelError),
),
}
} else {
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
#[cfg(not(target_os = "android"))]
{
Expand Down Expand Up @@ -385,17 +361,8 @@ impl ConnectedState {
#[cfg(target_os = "android")]
Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => {
if shared_values.set_excluded_paths(paths) {
if let Err(err) = shared_values.restart_tunnel(false) {
let _ =
result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err)));
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SplitTunnelError),
)
} else {
let _ = result_tx.send(Ok(()));
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
}
let _ = result_tx.send(Ok(()));
self.disconnect(shared_values, AfterDisconnect::Reconnect(0))
} else {
let _ = result_tx.send(Ok(()));
SameState(self)
Expand Down
Loading
Loading