[RefCounted and friends] Fix type safety of ref-counted types.
Previously, `RefCountedPtr<>` and `WeakRefCountedPtr<>` incorrectly allowed
implicit casting of any type to any other type.  This hadn't caused a
problem until recently, but now that it has, we need to fix it.  I have
fixed this by changing these smart pointer types to allow type
conversions only when the type used is convertible to the type of the
smart pointer.  This means that if `Subclass` inherits from `Base`, then
we can set a `RefCountedPtr<Base>` to a value of type
`RefCountedPtr<Subclass>`, but we cannot do the reverse.
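
As a rough illustration of the conversion rule (not the actual gRPC code), a converting constructor can be constrained so that it only participates in overload resolution when the pointee types are convertible. `SketchPtr`, `Base`, and `Subclass` below are hypothetical names, and ref counting is omitted entirely:

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical, simplified stand-in for a ref-counted smart pointer.
// It does no ref counting; it only demonstrates the conversion rule.
template <typename T>
class SketchPtr {
 public:
  SketchPtr() = default;
  explicit SketchPtr(T* p) : p_(p) {}

  // Enabled only when Y* converts implicitly to T* (e.g. subclass to base).
  template <typename Y,
            typename = std::enable_if_t<std::is_convertible<Y*, T*>::value>>
  SketchPtr(const SketchPtr<Y>& other) : p_(other.get()) {}

  T* get() const { return p_; }

 private:
  T* p_ = nullptr;
};

struct Base {
  virtual ~Base() = default;
};
struct Subclass : Base {};

// Up-casting is allowed; the reverse no longer compiles implicitly.
static_assert(std::is_convertible<SketchPtr<Subclass>, SketchPtr<Base>>::value,
              "subclass pointer converts to base pointer");
static_assert(!std::is_convertible<SketchPtr<Base>, SketchPtr<Subclass>>::value,
              "base pointer does not convert to subclass pointer");
```

With this constraint, `SketchPtr<Base> b = SketchPtr<Subclass>(&obj);` compiles, while the opposite assignment is rejected at compile time.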

We had been (ab)using this bug to make it more convenient to deal with
down-casting in subclasses of ref-counted types.  For example, because
`Resolver` inherits from `InternallyRefCounted<Resolver>`, calling
`Ref()` on a subclass of `Resolver` will return `RefCountedPtr<Resolver>`
rather than returning the subclass's type.  The ability to implicitly
convert to the subclass type made this a bit easier to deal with.  Now
that that ability is gone, we need a different way of dealing with that
problem.
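
To make the problem concrete, here is a minimal sketch of why `Ref()` yields the base type. `Handle`, `CountedBase`, `ResolverLike`, `DnsResolverLike`, and `RefAsDns()` are all hypothetical stand-ins, and the handle does not release its ref:

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical bare-bones handle; it only carries the static type.
template <typename T>
struct Handle {
  T* ptr;
};

// Hypothetical CRTP base in the style described above.
template <typename Child>
struct CountedBase {
  // The return type is fixed by the template argument,
  // not by the dynamic type of *this.
  Handle<Child> Ref() {
    ++refs;
    return Handle<Child>{static_cast<Child*>(this)};
  }
  int refs = 1;
};

struct ResolverLike : CountedBase<ResolverLike> {};
struct DnsResolverLike : ResolverLike {
  int port = 53;
};

// Calling Ref() on the subclass still produces a handle to the base type.
static_assert(
    std::is_same<decltype(std::declval<DnsResolverLike&>().Ref()),
                 Handle<ResolverLike>>::value,
    "Ref() returns the type named in the CRTP base");

// What a caller inside the subclass has to spell out explicitly.
inline Handle<DnsResolverLike> RefAsDns(DnsResolverLike& r) {
  return Handle<DnsResolverLike>{static_cast<DnsResolverLike*>(r.Ref().ptr)};
}
```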

I considered several ways of dealing with this, but none of them are
quite as ergonomic as I would ideally like.  For now, I've settled on
requiring callers to explicitly down-cast as needed, although I have
provided some utility functions to make this slightly easier:

- `RefCounted<>`, `InternallyRefCounted<>`, and `DualRefCounted<>` all
  provide a templated `RefAsSubclass<>()` method that will return a new
  ref as a subclass.  The type used with `RefAsSubclass()` must be a
  subclass of the type passed to `RefCounted<>`, `InternallyRefCounted<>`,
  or `DualRefCounted<>`.
- In addition, `DualRefCounted<>` provides a templated `WeakRefAsSubclass<>()`
  method.  This is the same as `RefAsSubclass()`, except that it returns
  a weak ref instead of a strong ref.
- In `RefCountedPtr<>`, I have added a new `Ref()` method that takes
  debug tracing parameters.  This can be used instead of calling `Ref()`
  on the underlying object in cases where the caller already has a
  `RefCountedPtr<>` and is calling `Ref()` only to specify the debug
  tracing parameters.  Using this method on `RefCountedPtr<>` is more
  ergonomic, because the smart pointer is already using the right
  subclass, so no down-casting is needed.
- In `WeakRefCountedPtr<>`, I have added a new `WeakRef()` method that
  takes debug tracing parameters.  This is the same as the new `Ref()`
  method on `RefCountedPtr<>`.
- In both `RefCountedPtr<>` and `WeakRefCountedPtr<>`, I have added a
  templated `TakeAsSubclass<>()` method that takes the ref out of the
  smart pointer and returns a new smart pointer of the down-casted type.
  Just as with the `RefAsSubclass()` method above, the type used with
  `TakeAsSubclass()` must be a subclass of the type passed to
  `RefCountedPtr<>` or `WeakRefCountedPtr<>`.
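
The helpers listed above can be sketched in miniature as follows. This assumes simplified, hypothetical types (`MiniPtr`, `MiniRefCounted`, `Policy`, `PickFirstLike`) with a non-atomic count; debug tracing parameters and weak refs are left out, so it shows the shape of the API rather than the real implementation:

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical miniature smart pointer (move-only, for brevity).
template <typename T>
class MiniPtr {
 public:
  MiniPtr() = default;
  explicit MiniPtr(T* p) : p_(p) {}  // adopts a ref that was already taken
  MiniPtr(MiniPtr&& o) noexcept : p_(o.release()) {}
  MiniPtr(const MiniPtr&) = delete;
  MiniPtr& operator=(const MiniPtr&) = delete;
  ~MiniPtr() {
    if (p_ != nullptr) p_->Unref();
  }

  T* get() const { return p_; }
  T* operator->() const { return p_; }
  T* release() {
    T* p = p_;
    p_ = nullptr;
    return p;
  }

  // New ref of the same static type; no down-cast is needed at the call site.
  MiniPtr<T> Ref() const {
    return p_ == nullptr ? MiniPtr<T>() : p_->template RefAsSubclass<T>();
  }

  // Moves the held ref into a pointer of the down-cast type.
  template <typename Sub,
            typename = std::enable_if_t<std::is_base_of<T, Sub>::value>>
  MiniPtr<Sub> TakeAsSubclass() {
    return MiniPtr<Sub>(static_cast<Sub*>(release()));
  }

 private:
  T* p_ = nullptr;
};

// Hypothetical CRTP base with a plain (non-atomic) count.
template <typename Child>
class MiniRefCounted {
 public:
  virtual ~MiniRefCounted() = default;

  MiniPtr<Child> Ref() {
    ++refs_;
    return MiniPtr<Child>(static_cast<Child*>(this));
  }

  // Sub must be Child or a subclass of Child.
  template <typename Sub,
            typename = std::enable_if_t<std::is_base_of<Child, Sub>::value>>
  MiniPtr<Sub> RefAsSubclass() {
    ++refs_;
    return MiniPtr<Sub>(static_cast<Sub*>(this));
  }

  void Unref() {
    if (--refs_ == 0) delete this;
  }
  int refs() const { return refs_; }

 private:
  int refs_ = 1;  // the creator holds the initial ref
};

struct Policy : MiniRefCounted<Policy> {};
struct PickFirstLike : Policy {
  int picks = 0;
};
```

In this sketch, `RefAsSubclass<>()` takes a new ref, whereas `TakeAsSubclass<>()` transfers an existing one and leaves the source pointer null.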

Note that I have *not* provided an `AsSubclass<>()` variant of the
`RefIfNonZero()` methods.  Those methods are used relatively rarely, so
it's not as important for them to be quite so ergonomic.  Callers of
these methods that need to down-cast can use
`RefIfNonZero().TakeAsSubclass<>()`.
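
A minimal sketch of that chaining, again with hypothetical types (`TinyPtr`, `TinyCounted`, `DataProducerLike`, `HealthProducerLike`). To keep the zero-count case safe to observe, the sketch assumes the object is owned elsewhere and is never deleted by `Unref()`:

```cpp
#include <atomic>
#include <cassert>
#include <type_traits>

// Hypothetical miniature smart pointer (move-only, for brevity).
template <typename T>
class TinyPtr {
 public:
  TinyPtr() = default;
  explicit TinyPtr(T* p) : p_(p) {}  // adopts a ref that was already taken
  TinyPtr(TinyPtr&& o) noexcept : p_(o.p_) { o.p_ = nullptr; }
  TinyPtr(const TinyPtr&) = delete;
  TinyPtr& operator=(const TinyPtr&) = delete;
  ~TinyPtr() {
    if (p_ != nullptr) p_->Unref();
  }

  T* get() const { return p_; }

  // Hands the held ref (possibly null) to a pointer of the down-cast type.
  template <typename Sub,
            typename = std::enable_if_t<std::is_base_of<T, Sub>::value>>
  TinyPtr<Sub> TakeAsSubclass() {
    T* p = p_;
    p_ = nullptr;
    return TinyPtr<Sub>(static_cast<Sub*>(p));
  }

 private:
  T* p_ = nullptr;
};

template <typename Child>
class TinyCounted {
 public:
  virtual ~TinyCounted() = default;

  // Takes a ref only if the count has not already dropped to zero.
  TinyPtr<Child> RefIfNonZero() {
    int n = refs_.load(std::memory_order_acquire);
    while (n != 0) {
      if (refs_.compare_exchange_weak(n, n + 1, std::memory_order_acq_rel,
                                      std::memory_order_acquire)) {
        return TinyPtr<Child>(static_cast<Child*>(this));
      }
    }
    return TinyPtr<Child>();
  }

  // Sketch-only assumption: dropping the last ref does not delete the object.
  void Unref() { refs_.fetch_sub(1, std::memory_order_acq_rel); }
  int refs() const { return refs_.load(std::memory_order_acquire); }

 private:
  std::atomic<int> refs_{1};
};

struct DataProducerLike : TinyCounted<DataProducerLike> {};
struct HealthProducerLike : DataProducerLike {};
```

A caller holding only a `DataProducerLike*` can then write `iface->RefIfNonZero().TakeAsSubclass<HealthProducerLike>()`, which yields a null pointer once the count has reached zero.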

PiperOrigin-RevId: 592327447
markdroth authored and copybara-github committed Dec 19, 2023
1 parent 85cede5 commit 3e785d3
Showing 53 changed files with 487 additions and 295 deletions.
13 changes: 8 additions & 5 deletions src/core/ext/filters/client_channel/client_channel.cc
@@ -714,8 +714,9 @@ class ClientChannel::SubchannelWrapper : public SubchannelInterface {
 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*chand_->work_serializer_) {
 auto& watcher_wrapper = watcher_map_[watcher.get()];
 GPR_ASSERT(watcher_wrapper == nullptr);
-watcher_wrapper = new WatcherWrapper(std::move(watcher),
-Ref(DEBUG_LOCATION, "WatcherWrapper"));
+watcher_wrapper = new WatcherWrapper(
+std::move(watcher),
+RefAsSubclass<SubchannelWrapper>(DEBUG_LOCATION, "WatcherWrapper"));
 subchannel_->WatchConnectivityState(
 RefCountedPtr<Subchannel::ConnectivityStateWatcherInterface>(
 watcher_wrapper));
@@ -919,7 +920,8 @@ ClientChannel::ExternalConnectivityWatcher::ExternalConnectivityWatcher(
 GPR_ASSERT(chand->external_watchers_[on_complete] == nullptr);
 // Store a ref to the watcher in the external_watchers_ map.
 chand->external_watchers_[on_complete] =
-Ref(DEBUG_LOCATION, "AddWatcherToExternalWatchersMapLocked");
+RefAsSubclass<ExternalConnectivityWatcher>(
+DEBUG_LOCATION, "AddWatcherToExternalWatchersMapLocked");
 }
 // Pass the ref from creating the object to Start().
 chand_->work_serializer_->Run(
@@ -3421,7 +3423,8 @@ void ClientChannel::FilterBasedLoadBalancedCall::TryPick(bool was_queued) {

 void ClientChannel::FilterBasedLoadBalancedCall::OnAddToQueueLocked() {
 // Register call combiner cancellation callback.
-lb_call_canceller_ = new LbQueuedCallCanceller(Ref());
+lb_call_canceller_ =
+new LbQueuedCallCanceller(RefAsSubclass<FilterBasedLoadBalancedCall>());
 }

 void ClientChannel::FilterBasedLoadBalancedCall::RetryPickLocked() {
@@ -3510,7 +3513,7 @@ ClientChannel::PromiseBasedLoadBalancedCall::MakeCallPromise(
 }
 // Extract peer name from server initial metadata.
 call_args.server_initial_metadata->InterceptAndMap(
-[self = RefCountedPtr<PromiseBasedLoadBalancedCall>(lb_call->Ref())](
+[self = lb_call->RefAsSubclass<PromiseBasedLoadBalancedCall>()](
 ServerMetadataHandle metadata) {
 if (self->call_attempt_tracer() != nullptr) {
 self->call_attempt_tracer()->RecordReceivedInitialMetadata(
@@ -28,7 +28,7 @@ namespace grpc_core {

 RefCountedPtr<GlobalSubchannelPool> GlobalSubchannelPool::instance() {
 static GlobalSubchannelPool* p = new GlobalSubchannelPool();
-return p->Ref();
+return p->RefAsSubclass<GlobalSubchannelPool>();
 }

 RefCountedPtr<Subchannel> GlobalSubchannelPool::RegisterSubchannel(
@@ -272,7 +272,8 @@ void ChildPolicyHandler::ResetBackoffLocked() {

 OrphanablePtr<LoadBalancingPolicy> ChildPolicyHandler::CreateChildPolicy(
 absl::string_view child_policy_name, const ChannelArgs& args) {
-Helper* helper = new Helper(Ref(DEBUG_LOCATION, "Helper"));
+Helper* helper =
+new Helper(RefAsSubclass<ChildPolicyHandler>(DEBUG_LOCATION, "Helper"));
 LoadBalancingPolicy::Args lb_policy_args;
 lb_policy_args.work_serializer = work_serializer();
 lb_policy_args.channel_control_helper =
31 changes: 15 additions & 16 deletions src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc
@@ -324,9 +324,8 @@ class GrpcLb : public LoadBalancingPolicy {
 }
 return;
 }
-WeakRefCountedPtr<SubchannelWrapper> self = WeakRef();
 lb_policy_->work_serializer()->Run(
-[self = std::move(self)]() {
+[self = WeakRefAsSubclass<SubchannelWrapper>()]() {
 if (!self->lb_policy_->shutting_down_) {
 self->lb_policy_->CacheDeletedSubchannelLocked(
 self->wrapped_subchannel());
@@ -819,8 +818,8 @@ RefCountedPtr<SubchannelInterface> GrpcLb::Helper::CreateSubchannel(
 return MakeRefCounted<SubchannelWrapper>(
 parent()->channel_control_helper()->CreateSubchannel(
 address, per_address_args, args),
-parent()->Ref(DEBUG_LOCATION, "SubchannelWrapper"), std::move(lb_token),
-std::move(client_stats));
+parent()->RefAsSubclass<GrpcLb>(DEBUG_LOCATION, "SubchannelWrapper"),
+std::move(lb_token), std::move(client_stats));
 }

 void GrpcLb::Helper::UpdateState(grpc_connectivity_state state,
@@ -1558,7 +1557,7 @@ absl::Status GrpcLb::UpdateLocked(UpdateArgs args) {
 gpr_log(GPR_INFO, "[grpclb %p] received update", this);
 }
 const bool is_initial_update = lb_channel_ == nullptr;
-config_ = args.config;
+config_ = args.config.TakeAsSubclass<GrpcLbConfig>();
 GPR_ASSERT(config_ != nullptr);
 args_ = std::move(args.args);
 // Update fallback address list.
@@ -1581,8 +1580,8 @@ absl::Status GrpcLb::UpdateLocked(UpdateArgs args) {
 lb_fallback_timer_handle_ =
 channel_control_helper()->GetEventEngine()->RunAfter(
 fallback_at_startup_timeout_,
-[self = static_cast<RefCountedPtr<GrpcLb>>(
-Ref(DEBUG_LOCATION, "on_fallback_timer"))]() mutable {
+[self = RefAsSubclass<GrpcLb>(DEBUG_LOCATION,
+"on_fallback_timer")]() mutable {
 ApplicationCallbackExecCtx callback_exec_ctx;
 ExecCtx exec_ctx;
 auto self_ptr = self.get();
@@ -1597,7 +1596,8 @@ absl::Status GrpcLb::UpdateLocked(UpdateArgs args) {
 ClientChannel::GetFromChannel(Channel::FromC(lb_channel_));
 GPR_ASSERT(client_channel != nullptr);
 // Ref held by callback.
-watcher_ = new StateWatcher(Ref(DEBUG_LOCATION, "StateWatcher"));
+watcher_ =
+new StateWatcher(RefAsSubclass<GrpcLb>(DEBUG_LOCATION, "StateWatcher"));
 client_channel->AddConnectivityWatcher(
 GRPC_CHANNEL_IDLE,
 OrphanablePtr<AsyncConnectivityStateWatcherInterface>(watcher_));
@@ -1640,11 +1640,10 @@ absl::Status GrpcLb::UpdateBalancerChannelLocked() {
 // Set up channelz linkage.
 channelz::ChannelNode* child_channelz_node =
 grpc_channel_get_channelz_node(lb_channel_);
-channelz::ChannelNode* parent_channelz_node =
-args_.GetObject<channelz::ChannelNode>();
+auto parent_channelz_node = args_.GetObjectRef<channelz::ChannelNode>();
 if (child_channelz_node != nullptr && parent_channelz_node != nullptr) {
 parent_channelz_node->AddChildChannel(child_channelz_node->uuid());
-parent_channelz_node_ = parent_channelz_node->Ref();
+parent_channelz_node_ = std::move(parent_channelz_node);
 }
 }
 // Propagate updates to the LB channel (pick_first) through the fake
@@ -1699,8 +1698,8 @@ void GrpcLb::StartBalancerCallRetryTimerLocked() {
 lb_call_retry_timer_handle_ =
 channel_control_helper()->GetEventEngine()->RunAfter(
 timeout,
-[self = static_cast<RefCountedPtr<GrpcLb>>(
-Ref(DEBUG_LOCATION, "on_balancer_call_retry_timer"))]() mutable {
+[self = RefAsSubclass<GrpcLb>(
+DEBUG_LOCATION, "on_balancer_call_retry_timer")]() mutable {
 ApplicationCallbackExecCtx callback_exec_ctx;
 ExecCtx exec_ctx;
 auto self_ptr = self.get();
@@ -1782,7 +1781,7 @@ OrphanablePtr<LoadBalancingPolicy> GrpcLb::CreateChildPolicyLocked(
 lb_policy_args.work_serializer = work_serializer();
 lb_policy_args.args = args;
 lb_policy_args.channel_control_helper =
-std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
+std::make_unique<Helper>(RefAsSubclass<GrpcLb>(DEBUG_LOCATION, "Helper"));
 OrphanablePtr<LoadBalancingPolicy> lb_policy =
 MakeOrphanable<ChildPolicyHandler>(std::move(lb_policy_args),
 &grpc_lb_glb_trace);
@@ -1867,8 +1866,8 @@ void GrpcLb::StartSubchannelCacheTimerLocked() {
 subchannel_cache_timer_handle_ =
 channel_control_helper()->GetEventEngine()->RunAfter(
 cached_subchannels_.begin()->first - Timestamp::Now(),
-[self = static_cast<RefCountedPtr<GrpcLb>>(
-Ref(DEBUG_LOCATION, "OnSubchannelCacheTimer"))]() mutable {
+[self = RefAsSubclass<GrpcLb>(DEBUG_LOCATION,
+"OnSubchannelCacheTimer")]() mutable {
 ApplicationCallbackExecCtx callback_exec_ctx;
 ExecCtx exec_ctx;
 auto* self_ptr = self.get();
@@ -356,7 +356,8 @@ void HealthProducer::Start(RefCountedPtr<Subchannel> subchannel) {
 MutexLock lock(&mu_);
 connected_subchannel_ = subchannel_->connected_subchannel();
 }
-auto connectivity_watcher = MakeRefCounted<ConnectivityWatcher>(WeakRef());
+auto connectivity_watcher =
+MakeRefCounted<ConnectivityWatcher>(WeakRefAsSubclass<HealthProducer>());
 connectivity_watcher_ = connectivity_watcher.get();
 subchannel_->WatchConnectivityState(std::move(connectivity_watcher));
 }
@@ -387,7 +388,8 @@ void HealthProducer::AddWatcher(
 health_checkers_.emplace(*health_check_service_name, nullptr).first;
 auto& health_checker = it->second;
 if (health_checker == nullptr) {
-health_checker = MakeOrphanable<HealthChecker>(WeakRef(), it->first);
+health_checker = MakeOrphanable<HealthChecker>(
+WeakRefAsSubclass<HealthProducer>(), it->first);
 }
 health_checker->AddWatcherLocked(watcher);
 }
@@ -456,7 +458,10 @@ void HealthWatcher::SetSubchannel(Subchannel* subchannel) {
 subchannel->GetOrAddDataProducer(
 HealthProducer::Type(),
 [&](Subchannel::DataProducerInterface** producer) {
-if (*producer != nullptr) producer_ = (*producer)->RefIfNonZero();
+if (*producer != nullptr) {
+producer_ =
+(*producer)->RefIfNonZero().TakeAsSubclass<HealthProducer>();
+}
 if (producer_ == nullptr) {
 producer_ = MakeRefCounted<HealthProducer>();
 *producer = producer_.get();
@@ -215,7 +215,8 @@ class OrcaProducer::OrcaStreamEventHandler
 void OrcaProducer::Start(RefCountedPtr<Subchannel> subchannel) {
 subchannel_ = std::move(subchannel);
 connected_subchannel_ = subchannel_->connected_subchannel();
-auto connectivity_watcher = MakeRefCounted<ConnectivityWatcher>(WeakRef());
+auto connectivity_watcher =
+MakeRefCounted<ConnectivityWatcher>(WeakRefAsSubclass<OrcaProducer>());
 connectivity_watcher_ = connectivity_watcher.get();
 subchannel_->WatchConnectivityState(std::move(connectivity_watcher));
 }
@@ -269,7 +270,8 @@ void OrcaProducer::MaybeStartStreamLocked() {
 if (connected_subchannel_ == nullptr) return;
 stream_client_ = MakeOrphanable<SubchannelStreamClient>(
 connected_subchannel_, subchannel_->pollset_set(),
-std::make_unique<OrcaStreamEventHandler>(WeakRef(), report_interval_),
+std::make_unique<OrcaStreamEventHandler>(
+WeakRefAsSubclass<OrcaProducer>(), report_interval_),
 GRPC_TRACE_FLAG_ENABLED(grpc_orca_client_trace) ? "OrcaClient" : nullptr);
 }

@@ -310,7 +312,10 @@ void OrcaWatcher::SetSubchannel(Subchannel* subchannel) {
 // If not, create a new one.
 subchannel->GetOrAddDataProducer(
 OrcaProducer::Type(), [&](Subchannel::DataProducerInterface** producer) {
-if (*producer != nullptr) producer_ = (*producer)->RefIfNonZero();
+if (*producer != nullptr) {
+producer_ =
+(*producer)->RefIfNonZero().TakeAsSubclass<OrcaProducer>();
+}
 if (producer_ == nullptr) {
 producer_ = MakeRefCounted<OrcaProducer>();
 *producer = producer_.get();
@@ -151,9 +151,8 @@ class OutlierDetectionLb : public LoadBalancingPolicy {
 }
 return;
 }
-WeakRefCountedPtr<SubchannelWrapper> self = WeakRef();
 work_serializer_->Run(
-[self = std::move(self)]() {
+[self = WeakRefAsSubclass<SubchannelWrapper>()]() {
 if (self->subchannel_state_ != nullptr) {
 self->subchannel_state_->RemoveSubchannel(self.get());
 }
@@ -624,7 +623,7 @@ absl::Status OutlierDetectionLb::UpdateLocked(UpdateArgs args) {
 }
 auto old_config = std::move(config_);
 // Update config.
-config_ = std::move(args.config);
+config_ = args.config.TakeAsSubclass<OutlierDetectionLbConfig>();
 // Update outlier detection timer.
 if (!config_->CountingEnabled()) {
 // No need for timer. Cancel the current timer, if any.
@@ -639,7 +638,8 @@ absl::Status OutlierDetectionLb::UpdateLocked(UpdateArgs args) {
 if (GRPC_TRACE_FLAG_ENABLED(grpc_outlier_detection_lb_trace)) {
 gpr_log(GPR_INFO, "[outlier_detection_lb %p] starting timer", this);
 }
-ejection_timer_ = MakeOrphanable<EjectionTimer>(Ref(), Timestamp::Now());
+ejection_timer_ = MakeOrphanable<EjectionTimer>(
+RefAsSubclass<OutlierDetectionLb>(), Timestamp::Now());
 for (const auto& p : endpoint_state_map_) {
 p.second->RotateBucket(); // Reset call counters.
 }
@@ -654,8 +654,8 @@ absl::Status OutlierDetectionLb::UpdateLocked(UpdateArgs args) {
 "[outlier_detection_lb %p] interval changed, replacing timer",
 this);
 }
-ejection_timer_ =
-MakeOrphanable<EjectionTimer>(Ref(), ejection_timer_->StartTime());
+ejection_timer_ = MakeOrphanable<EjectionTimer>(
+RefAsSubclass<OutlierDetectionLb>(), ejection_timer_->StartTime());
 }
 // Update subchannel and endpoint maps.
 if (args.addresses.ok()) {
@@ -783,8 +783,8 @@ OrphanablePtr<LoadBalancingPolicy> OutlierDetectionLb::CreateChildPolicyLocked(
 LoadBalancingPolicy::Args lb_policy_args;
 lb_policy_args.work_serializer = work_serializer();
 lb_policy_args.args = args;
-lb_policy_args.channel_control_helper =
-std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
+lb_policy_args.channel_control_helper = std::make_unique<Helper>(
+RefAsSubclass<OutlierDetectionLb>(DEBUG_LOCATION, "Helper"));
 OrphanablePtr<LoadBalancingPolicy> lb_policy =
 MakeOrphanable<ChildPolicyHandler>(std::move(lb_policy_args),
 &grpc_outlier_detection_lb_trace);
@@ -418,7 +418,7 @@ void PickFirst::AttemptToConnectUsingLatestUpdateArgsLocked() {
 latest_pending_subchannel_list_.get());
 }
 latest_pending_subchannel_list_ = MakeOrphanable<SubchannelList>(
-Ref(), addresses, latest_update_args_.args);
+RefAsSubclass<PickFirst>(), addresses, latest_update_args_.args);
 // Empty update or no valid subchannels. Put the channel in
 // TRANSIENT_FAILURE and request re-resolution.
 if (latest_pending_subchannel_list_->size() == 0) {
@@ -1030,7 +1030,7 @@ void PickFirst::SubchannelList::SubchannelData::ProcessUnselectedReadyLocked() {
 gpr_log(GPR_INFO, "[PF %p] starting health watch", p);
 }
 auto watcher = std::make_unique<HealthWatcher>(
-p->Ref(DEBUG_LOCATION, "HealthWatcher"));
+p->RefAsSubclass<PickFirst>(DEBUG_LOCATION, "HealthWatcher"));
 p->health_watcher_ = watcher.get();
 auto health_data_watcher = MakeHealthCheckWatcher(
 p->work_serializer(), subchannel_list_->args_, std::move(watcher));
@@ -335,7 +335,7 @@ absl::Status PriorityLb::UpdateLocked(UpdateArgs args) {
 gpr_log(GPR_INFO, "[priority_lb %p] received update", this);
 }
 // Update config.
-config_ = std::move(args.config);
+config_ = args.config.TakeAsSubclass<PriorityLbConfig>();
 // Update args.
 args_ = std::move(args.args);
 // Update addresses.
@@ -411,7 +411,8 @@ void PriorityLb::ChoosePriorityLocked() {
 // Create child if needed.
 if (child == nullptr) {
 child = MakeOrphanable<ChildPriority>(
-Ref(DEBUG_LOCATION, "ChildPriority"), child_name);
+RefAsSubclass<PriorityLb>(DEBUG_LOCATION, "ChildPriority"),
+child_name);
 auto child_config = config_->children().find(child_name);
 GPR_DEBUG_ASSERT(child_config != config_->children().end());
 // TODO(roth): If the child reports a non-OK status with the
@@ -346,7 +346,7 @@ RingHash::PickResult RingHash::Picker::Pick(PickArgs args) {
 return endpoint_info.picker->Pick(args);
 case GRPC_CHANNEL_IDLE:
 new EndpointConnectionAttempter(
-ring_hash_->Ref(DEBUG_LOCATION, "EndpointConnectionAttempter"),
+ring_hash_.Ref(DEBUG_LOCATION, "EndpointConnectionAttempter"),
 endpoint_info.endpoint);
 ABSL_FALLTHROUGH_INTENDED;
 case GRPC_CHANNEL_CONNECTING:
@@ -677,8 +677,8 @@ absl::Status RingHash::UpdateLocked(UpdateArgs args) {
 it->second->UpdateLocked(i);
 endpoint_map.emplace(address_set, std::move(it->second));
 } else {
-endpoint_map.emplace(address_set,
-MakeOrphanable<RingHashEndpoint>(Ref(), i));
+endpoint_map.emplace(address_set, MakeOrphanable<RingHashEndpoint>(
+RefAsSubclass<RingHash>(), i));
 }
 }
 endpoint_map_ = std::move(endpoint_map);
@@ -779,7 +779,8 @@ void RingHash::UpdateAggregatedConnectivityStateLocked(
 // Note that we use our own picker regardless of connectivity state.
 channel_control_helper()->UpdateState(
 state, status,
-MakeRefCounted<Picker>(Ref(DEBUG_LOCATION, "RingHashPicker")));
+MakeRefCounted<Picker>(
+RefAsSubclass<RingHash>(DEBUG_LOCATION, "RingHashPicker")));
 // While the ring_hash policy is reporting TRANSIENT_FAILURE, it will
 // not be getting any pick requests from the priority policy.
 // However, because the ring_hash policy does not attempt to