diff --git a/include/linux/udp.h b/include/linux/udp.h index 0807e21cfec95..a69da9c4c1c5c 100644 --- a/include/linux/udp.h +++ b/include/linux/udp.h @@ -209,6 +209,9 @@ static inline void udp_allow_gso(struct sock *sk) #define udp_portaddr_for_each_entry(__sk, list) \ hlist_for_each_entry(__sk, list, __sk_common.skc_portaddr_node) +#define udp_portaddr_for_each_entry_from(__sk) \ + hlist_for_each_entry_from(__sk, __sk_common.skc_portaddr_node) + #define udp_portaddr_for_each_entry_rcu(__sk, list) \ hlist_for_each_entry_rcu(__sk, list, __sk_common.skc_portaddr_node) diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 2742cc7602bb5..cf6285bae4f58 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -93,6 +93,7 @@ #include #include #include +#include #include #include #include @@ -3390,34 +3391,55 @@ struct bpf_iter__udp { int bucket __aligned(8); }; +union bpf_udp_iter_batch_item { + struct sock *sk; + __u64 cookie; +}; + struct bpf_udp_iter_state { struct udp_iter_state state; unsigned int cur_sk; unsigned int end_sk; unsigned int max_sk; - int offset; - struct sock **batch; - bool st_bucket_done; + union bpf_udp_iter_batch_item *batch; }; static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter, - unsigned int new_batch_sz); + unsigned int new_batch_sz, int flags); +static struct sock *bpf_iter_udp_resume(struct sock *first_sk, + union bpf_udp_iter_batch_item *cookies, + int n_cookies) +{ + struct sock *sk = NULL; + int i = 0; + + for (; i < n_cookies; i++) { + sk = first_sk; + udp_portaddr_for_each_entry_from(sk) + if (cookies[i].cookie == atomic64_read(&sk->sk_cookie)) + goto done; + } +done: + return sk; +} + static struct sock *bpf_iter_udp_batch(struct seq_file *seq) { struct bpf_udp_iter_state *iter = seq->private; struct udp_iter_state *state = &iter->state; + unsigned int find_cookie, end_cookie = 0; struct net *net = seq_file_net(seq); - int resume_bucket, resume_offset; struct udp_table *udptable; unsigned int batch_sks = 0; - bool resized = false; + int resume_bucket; + int resizes = 0; struct sock *sk; + int err = 0; resume_bucket = state->bucket; - resume_offset = iter->offset; /* The current batch is done, so advance the bucket. */ - if (iter->st_bucket_done) + if (iter->cur_sk == iter->end_sk) state->bucket++; udptable = udp_get_table_seq(seq, net); @@ -3430,62 +3452,89 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq) * before releasing the bucket lock. This allows BPF programs that are * called in seq_show to acquire the bucket lock if needed. */ + find_cookie = iter->cur_sk; + end_cookie = iter->end_sk; iter->cur_sk = 0; iter->end_sk = 0; - iter->st_bucket_done = false; batch_sks = 0; for (; state->bucket <= udptable->mask; state->bucket++) { struct udp_hslot *hslot2 = &udptable->hash2[state->bucket].hslot; if (hlist_empty(&hslot2->head)) - continue; + goto next_bucket; - iter->offset = 0; spin_lock_bh(&hslot2->lock); - udp_portaddr_for_each_entry(sk, &hslot2->head) { + sk = hlist_entry_safe(hslot2->head.first, struct sock, + __sk_common.skc_portaddr_node); + /* Resume from the first (in iteration order) unseen socket from + * the last batch that still exists in resume_bucket. Most of + * the time this will just be where the last iteration left off + * in resume_bucket unless that socket disappeared between + * reads. + */ + if (state->bucket == resume_bucket) + sk = bpf_iter_udp_resume(sk, &iter->batch[find_cookie], + end_cookie - find_cookie); +fill_batch: + udp_portaddr_for_each_entry_from(sk) { if (seq_sk_match(seq, sk)) { - /* Resume from the last iterated socket at the - * offset in the bucket before iterator was stopped. - */ - if (state->bucket == resume_bucket && - iter->offset < resume_offset) { - ++iter->offset; - continue; - } if (iter->end_sk < iter->max_sk) { sock_hold(sk); - iter->batch[iter->end_sk++] = sk; + iter->batch[iter->end_sk++].sk = sk; } batch_sks++; } } + + /* Allocate a larger batch and try again. */ + if (unlikely(resizes <= 1 && iter->end_sk && + iter->end_sk != batch_sks)) { + resizes++; + + /* First, try with GFP_USER to maximize the chances of + * grabbing more memory. + */ + if (resizes == 1) { + spin_unlock_bh(&hslot2->lock); + err = bpf_iter_udp_realloc_batch(iter, + batch_sks * 3 / 2, + GFP_USER); + if (err) + return ERR_PTR(err); + /* Start over. */ + goto again; + } + + /* Next, hold onto the lock, so the bucket doesn't + * change while we get the rest of the sockets. + */ + err = bpf_iter_udp_realloc_batch(iter, batch_sks, + GFP_NOWAIT); + if (err) { + spin_unlock_bh(&hslot2->lock); + return ERR_PTR(err); + } + + /* Pick up where we left off. */ + sk = iter->batch[iter->end_sk - 1].sk; + sk = hlist_entry_safe(sk->__sk_common.skc_portaddr_node.next, + struct sock, + __sk_common.skc_portaddr_node); + batch_sks = iter->end_sk; + goto fill_batch; + } + spin_unlock_bh(&hslot2->lock); if (iter->end_sk) break; +next_bucket: + resizes = 0; } - /* All done: no batch made. */ - if (!iter->end_sk) - return NULL; - - if (iter->end_sk == batch_sks) { - /* Batching is done for the current bucket; return the first - * socket to be iterated from the batch. - */ - iter->st_bucket_done = true; - goto done; - } - if (!resized && !bpf_iter_udp_realloc_batch(iter, batch_sks * 3 / 2)) { - resized = true; - /* After allocating a larger batch, retry one more time to grab - * the whole bucket. - */ - goto again; - } -done: - return iter->batch[0]; + WARN_ON_ONCE(iter->end_sk != batch_sks); + return iter->end_sk ? iter->batch[0].sk : NULL; } static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) @@ -3496,16 +3545,14 @@ static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) /* Whenever seq_next() is called, the iter->cur_sk is * done with seq_show(), so unref the iter->cur_sk. */ - if (iter->cur_sk < iter->end_sk) { - sock_put(iter->batch[iter->cur_sk++]); - ++iter->offset; - } + if (iter->cur_sk < iter->end_sk) + sock_put(iter->batch[iter->cur_sk++].sk); /* After updating iter->cur_sk, check if there are more sockets * available in the current bucket batch. */ if (iter->cur_sk < iter->end_sk) - sk = iter->batch[iter->cur_sk]; + sk = iter->batch[iter->cur_sk].sk; else /* Prepare a new batch. */ sk = bpf_iter_udp_batch(seq); @@ -3569,8 +3616,19 @@ static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v) static void bpf_iter_udp_put_batch(struct bpf_udp_iter_state *iter) { - while (iter->cur_sk < iter->end_sk) - sock_put(iter->batch[iter->cur_sk++]); + union bpf_udp_iter_batch_item *item; + unsigned int cur_sk = iter->cur_sk; + __u64 cookie; + + /* Remember the cookies of the sockets we haven't seen yet, so we can + * pick up where we left off next time around. + */ + while (cur_sk < iter->end_sk) { + item = &iter->batch[cur_sk++]; + cookie = sock_gen_cookie(item->sk); + sock_put(item->sk); + item->cookie = cookie; + } } static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v) @@ -3586,10 +3644,8 @@ static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v) (void)udp_prog_seq_show(prog, &meta, v, 0, 0); } - if (iter->cur_sk < iter->end_sk) { + if (iter->cur_sk < iter->end_sk) bpf_iter_udp_put_batch(iter); - iter->st_bucket_done = false; - } } static const struct seq_operations bpf_iter_udp_seq_ops = { @@ -3831,16 +3887,19 @@ DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta, struct udp_sock *udp_sk, uid_t uid, int bucket) static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter, - unsigned int new_batch_sz) + unsigned int new_batch_sz, int flags) { - struct sock **new_batch; + union bpf_udp_iter_batch_item *new_batch; new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch), - GFP_USER | __GFP_NOWARN); + flags | __GFP_NOWARN); if (!new_batch) return -ENOMEM; - bpf_iter_udp_put_batch(iter); + if (flags != GFP_NOWAIT) + bpf_iter_udp_put_batch(iter); + + memcpy(new_batch, iter->batch, sizeof(*iter->batch) * iter->end_sk); kvfree(iter->batch); iter->batch = new_batch; iter->max_sk = new_batch_sz; @@ -3859,10 +3918,12 @@ static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux) if (ret) return ret; - ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ); + ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ, GFP_USER); if (ret) bpf_iter_fini_seq_net(priv_data); + iter->state.bucket = -1; + return ret; } diff --git a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c index d56e18b255280..a4517bee34d5b 100644 --- a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c +++ b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c @@ -7,14 +7,433 @@ #define TEST_NS "sock_iter_batch_netns" +static const int init_batch_size = 16; static const int nr_soreuse = 4; +struct iter_out { + int idx; + __u64 cookie; +} __packed; + +struct sock_count { + __u64 cookie; + int count; +}; + +static int insert(__u64 cookie, struct sock_count counts[], int counts_len) +{ + int insert = -1; + int i = 0; + + for (; i < counts_len; i++) { + if (!counts[i].cookie) { + insert = i; + } else if (counts[i].cookie == cookie) { + insert = i; + break; + } + } + if (insert < 0) + return insert; + + counts[insert].cookie = cookie; + counts[insert].count++; + + return counts[insert].count; +} + +static int read_n(int iter_fd, int n, struct sock_count counts[], + int counts_len) +{ + struct iter_out out; + int nread = 1; + int i = 0; + + for (; nread > 0 && (n < 0 || i < n); i++) { + nread = read(iter_fd, &out, sizeof(out)); + if (!nread || !ASSERT_EQ(nread, sizeof(out), "nread")) + break; + ASSERT_GE(insert(out.cookie, counts, counts_len), 0, "insert"); + } + + ASSERT_TRUE(n < 0 || i == n, "n < 0 || i == n"); + + return i; +} + +static __u64 socket_cookie(int fd) +{ + __u64 cookie; + socklen_t cookie_len = sizeof(cookie); + + if (!ASSERT_OK(getsockopt(fd, SOL_SOCKET, SO_COOKIE, &cookie, + &cookie_len), "getsockopt(SO_COOKIE)")) + return 0; + return cookie; +} + +static bool was_seen(int fd, struct sock_count counts[], int counts_len) +{ + __u64 cookie = socket_cookie(fd); + int i = 0; + + for (; cookie && i < counts_len; i++) + if (cookie == counts[i].cookie) + return true; + + return false; +} + +static int get_seen_socket(int *fds, struct sock_count counts[], int n) +{ + int i = 0; + + for (; i < n; i++) + if (was_seen(fds[i], counts, n)) + return i; + return -1; +} + +static int get_nth_socket(int *fds, int fds_len, struct bpf_link *link, int n) +{ + int i, nread, iter_fd; + int nth_sock_idx = -1; + struct iter_out out; + + iter_fd = bpf_iter_create(bpf_link__fd(link)); + if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create")) + return -1; + + for (; n >= 0; n--) { + nread = read(iter_fd, &out, sizeof(out)); + if (!nread || !ASSERT_GE(nread, 1, "nread")) + goto done; + } + + for (i = 0; i < fds_len && nth_sock_idx < 0; i++) + if (fds[i] >= 0 && socket_cookie(fds[i]) == out.cookie) + nth_sock_idx = i; +done: + close(iter_fd); + return nth_sock_idx; +} + +static int get_seen_count(int fd, struct sock_count counts[], int n) +{ + __u64 cookie = socket_cookie(fd); + int count = 0; + int i = 0; + + for (; cookie && !count && i < n; i++) + if (cookie == counts[i].cookie) + count = counts[i].count; + + return count; +} + +static void check_n_were_seen_once(int *fds, int fds_len, int n, + struct sock_count counts[], int counts_len) +{ + int seen_once = 0; + int seen_cnt; + int i = 0; + + for (; i < fds_len; i++) { + /* Skip any sockets that were closed or that weren't seen + * exactly once. + */ + if (fds[i] < 0) + continue; + seen_cnt = get_seen_count(fds[i], counts, counts_len); + if (seen_cnt && ASSERT_EQ(seen_cnt, 1, "seen_cnt")) + seen_once++; + } + + ASSERT_EQ(seen_once, n, "seen_once"); +} + +static void remove_seen(int family, int sock_type, const char *addr, __u16 port, + int *socks, int socks_len, struct sock_count *counts, + int counts_len, struct bpf_link *link, int iter_fd) +{ + int close_idx; + + /* Iterate through the first socks_len - 1 sockets. */ + read_n(iter_fd, socks_len - 1, counts, counts_len); + + /* Make sure we saw socks_len - 1 sockets exactly once. */ + check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, + counts_len); + + /* Close a socket we've already seen to remove it from the bucket. */ + close_idx = get_seen_socket(socks, counts, counts_len); + if (!ASSERT_GE(close_idx, 0, "close_idx")) + return; + close(socks[close_idx]); + socks[close_idx] = -1; + + /* Iterate through the rest of the sockets. */ + read_n(iter_fd, -1, counts, counts_len); + + /* Make sure the last socket wasn't skipped and that there were no + * repeats. + */ + check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, + counts_len); +} + +static void remove_unseen(int family, int sock_type, const char *addr, + __u16 port, int *socks, int socks_len, + struct sock_count *counts, int counts_len, + struct bpf_link *link, int iter_fd) +{ + int close_idx; + + /* Iterate through the first socket. */ + read_n(iter_fd, 1, counts, counts_len); + + /* Make sure we saw a socket from fds. */ + check_n_were_seen_once(socks, socks_len, 1, counts, counts_len); + + /* Close what would be the next socket in the bucket to exercise the + * condition where we need to skip past the first cookie we remembered. + */ + close_idx = get_nth_socket(socks, socks_len, link, 1); + if (!ASSERT_GE(close_idx, 0, "close_idx")) + return; + close(socks[close_idx]); + socks[close_idx] = -1; + + /* Iterate through the rest of the sockets. */ + read_n(iter_fd, -1, counts, counts_len); + + /* Make sure the remaining sockets were seen exactly once and that we + * didn't repeat the socket that was already seen. + */ + check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, + counts_len); +} + +static void remove_all(int family, int sock_type, const char *addr, + __u16 port, int *socks, int socks_len, + struct sock_count *counts, int counts_len, + struct bpf_link *link, int iter_fd) +{ + int close_idx, i; + + /* Iterate through the first socket. */ + read_n(iter_fd, 1, counts, counts_len); + + /* Make sure we saw a socket from fds. */ + check_n_were_seen_once(socks, socks_len, 1, counts, counts_len); + + /* Close all remaining sockets to exhaust the list of saved cookies and + * exit without putting any sockets into the batch on the next read. + */ + for (i = 0; i < socks_len - 1; i++) { + close_idx = get_nth_socket(socks, socks_len, link, 1); + if (!ASSERT_GE(close_idx, 0, "close_idx")) + return; + close(socks[close_idx]); + socks[close_idx] = -1; + } + + /* Make sure there are no more sockets returned */ + ASSERT_EQ(read_n(iter_fd, -1, counts, counts_len), 0, "read_n"); +} + +static void add_some(int family, int sock_type, const char *addr, __u16 port, + int *socks, int socks_len, struct sock_count *counts, + int counts_len, struct bpf_link *link, int iter_fd) +{ + int *new_socks = NULL; + + /* Iterate through the first socks_len - 1 sockets. */ + read_n(iter_fd, socks_len - 1, counts, counts_len); + + /* Make sure we saw socks_len - 1 sockets exactly once. */ + check_n_were_seen_once(socks, socks_len, socks_len - 1, counts, + counts_len); + + /* Double the number of sockets in the bucket. */ + new_socks = start_reuseport_server(family, sock_type, addr, port, 0, + socks_len); + if (!ASSERT_OK_PTR(new_socks, "start_reuseport_server")) + goto done; + + /* Iterate through the rest of the sockets. */ + read_n(iter_fd, -1, counts, counts_len); + + /* Make sure each of the original sockets was seen exactly once. */ + check_n_were_seen_once(socks, socks_len, socks_len, counts, + counts_len); +done: + free_fds(new_socks, socks_len); +} + +static void force_realloc(int family, int sock_type, const char *addr, + __u16 port, int *socks, int socks_len, + struct sock_count *counts, int counts_len, + struct bpf_link *link, int iter_fd) +{ + int *new_socks = NULL; + + /* Iterate through the first socket just to initialize the batch. */ + read_n(iter_fd, 1, counts, counts_len); + + /* Double the number of sockets in the bucket to force a realloc on the + * next read. + */ + new_socks = start_reuseport_server(family, sock_type, addr, port, 0, + socks_len); + if (!ASSERT_OK_PTR(new_socks, "start_reuseport_server")) + goto done; + + /* Iterate through the rest of the sockets. */ + read_n(iter_fd, -1, counts, counts_len); + + /* Make sure each socket from the first set was seen exactly once. */ + check_n_were_seen_once(socks, socks_len, socks_len, counts, + counts_len); +done: + free_fds(new_socks, socks_len); +} + +struct test_case { + void (*test)(int family, int sock_type, const char *addr, __u16 port, + int *socks, int socks_len, struct sock_count *counts, + int counts_len, struct bpf_link *link, int iter_fd); + const char *description; + int init_socks; + int max_socks; + int sock_type; + int family; +}; + +static struct test_case resume_tests[] = { + { + .description = "udp: resume after removing a seen socket", + .init_socks = nr_soreuse, + .max_socks = nr_soreuse, + .sock_type = SOCK_DGRAM, + .family = AF_INET6, + .test = remove_seen, + }, + { + .description = "udp: resume after removing one unseen socket", + .init_socks = nr_soreuse, + .max_socks = nr_soreuse, + .sock_type = SOCK_DGRAM, + .family = AF_INET6, + .test = remove_unseen, + }, + { + .description = "udp: resume after removing all unseen sockets", + .init_socks = nr_soreuse, + .max_socks = nr_soreuse, + .sock_type = SOCK_DGRAM, + .family = AF_INET6, + .test = remove_all, + }, + { + .description = "udp: resume after adding a few sockets", + .init_socks = nr_soreuse, + .max_socks = nr_soreuse, + .sock_type = SOCK_DGRAM, + /* Use AF_INET so that new sockets are added to the head of the + * bucket's list. + */ + .family = AF_INET, + .test = add_some, + }, + { + .description = "udp: force a realloc to occur", + .init_socks = init_batch_size, + .max_socks = init_batch_size * 2, + .sock_type = SOCK_DGRAM, + /* Use AF_INET6 so that new sockets are added to the tail of the + * bucket's list, needing to be added to the next batch to force + * a realloc. + */ + .family = AF_INET6, + .test = force_realloc, + }, +}; + +static void do_resume_test(struct test_case *tc) +{ + struct sock_iter_batch *skel = NULL; + static const __u16 port = 10001; + struct bpf_link *link = NULL; + struct sock_count *counts; + int err, iter_fd = -1; + const char *addr; + int *fds = NULL; + int local_port; + + counts = calloc(tc->max_socks, sizeof(*counts)); + if (!ASSERT_OK_PTR(counts, "counts")) + goto done; + skel = sock_iter_batch__open(); + if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open")) + goto done; + + /* Prepare a bucket of sockets in the kernel hashtable */ + addr = tc->family == AF_INET6 ? "::1" : "127.0.0.1"; + fds = start_reuseport_server(tc->family, tc->sock_type, addr, port, 0, + tc->init_socks); + if (!ASSERT_OK_PTR(fds, "start_reuseport_server")) + goto done; + local_port = get_socket_local_port(*fds); + if (!ASSERT_GE(local_port, 0, "get_socket_local_port")) + goto done; + skel->rodata->ports[0] = ntohs(local_port); + skel->rodata->sf = tc->family; + + err = sock_iter_batch__load(skel); + if (!ASSERT_OK(err, "sock_iter_batch__load")) + goto done; + + link = bpf_program__attach_iter(tc->sock_type == SOCK_STREAM ? + skel->progs.iter_tcp_soreuse : + skel->progs.iter_udp_soreuse, + NULL); + if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) + goto done; + + iter_fd = bpf_iter_create(bpf_link__fd(link)); + if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create")) + goto done; + + tc->test(tc->family, tc->sock_type, addr, port, fds, tc->init_socks, + counts, tc->max_socks, link, iter_fd); +done: + free(counts); + free_fds(fds, tc->init_socks); + if (iter_fd >= 0) + close(iter_fd); + bpf_link__destroy(link); + sock_iter_batch__destroy(skel); +} + +static void do_resume_tests(void) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(resume_tests); i++) { + if (test__start_subtest(resume_tests[i].description)) { + do_resume_test(&resume_tests[i]); + } + } +} + static void do_test(int sock_type, bool onebyone) { int err, i, nread, to_read, total_read, iter_fd = -1; - int first_idx, second_idx, indices[nr_soreuse]; + struct iter_out outputs[nr_soreuse]; struct bpf_link *link = NULL; struct sock_iter_batch *skel; + int first_idx, second_idx; int *fds[2] = {}; skel = sock_iter_batch__open(); @@ -34,6 +453,7 @@ static void do_test(int sock_type, bool onebyone) goto done; skel->rodata->ports[i] = ntohs(local_port); } + skel->rodata->sf = AF_INET6; err = sock_iter_batch__load(skel); if (!ASSERT_OK(err, "sock_iter_batch__load")) @@ -55,38 +475,38 @@ static void do_test(int sock_type, bool onebyone) * from a bucket and leave one socket out from * that bucket on purpose. */ - to_read = (nr_soreuse - 1) * sizeof(*indices); + to_read = (nr_soreuse - 1) * sizeof(*outputs); total_read = 0; first_idx = -1; do { - nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read); - if (nread <= 0 || nread % sizeof(*indices)) + nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read); + if (nread <= 0 || nread % sizeof(*outputs)) break; total_read += nread; if (first_idx == -1) - first_idx = indices[0]; - for (i = 0; i < nread / sizeof(*indices); i++) - ASSERT_EQ(indices[i], first_idx, "first_idx"); + first_idx = outputs[0].idx; + for (i = 0; i < nread / sizeof(*outputs); i++) + ASSERT_EQ(outputs[i].idx, first_idx, "first_idx"); } while (total_read < to_read); - ASSERT_EQ(nread, onebyone ? sizeof(*indices) : to_read, "nread"); + ASSERT_EQ(nread, onebyone ? sizeof(*outputs) : to_read, "nread"); ASSERT_EQ(total_read, to_read, "total_read"); free_fds(fds[first_idx], nr_soreuse); fds[first_idx] = NULL; /* Read the "whole" second bucket */ - to_read = nr_soreuse * sizeof(*indices); + to_read = nr_soreuse * sizeof(*outputs); total_read = 0; second_idx = !first_idx; do { - nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read); - if (nread <= 0 || nread % sizeof(*indices)) + nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read); + if (nread <= 0 || nread % sizeof(*outputs)) break; total_read += nread; - for (i = 0; i < nread / sizeof(*indices); i++) - ASSERT_EQ(indices[i], second_idx, "second_idx"); + for (i = 0; i < nread / sizeof(*outputs); i++) + ASSERT_EQ(outputs[i].idx, second_idx, "second_idx"); } while (total_read <= to_read); ASSERT_EQ(nread, 0, "nread"); /* Both so_reuseport ports should be in different buckets, so @@ -128,6 +548,7 @@ void test_sock_iter_batch(void) do_test(SOCK_DGRAM, true); do_test(SOCK_DGRAM, false); } + do_resume_tests(); close_netns(nstoken); done: diff --git a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h index 659694162739e..17db400f0e0d9 100644 --- a/tools/testing/selftests/bpf/progs/bpf_tracing_net.h +++ b/tools/testing/selftests/bpf/progs/bpf_tracing_net.h @@ -128,6 +128,7 @@ #define sk_refcnt __sk_common.skc_refcnt #define sk_state __sk_common.skc_state #define sk_net __sk_common.skc_net +#define sk_rcv_saddr __sk_common.skc_rcv_saddr #define sk_v6_daddr __sk_common.skc_v6_daddr #define sk_v6_rcv_saddr __sk_common.skc_v6_rcv_saddr #define sk_flags __sk_common.skc_flags diff --git a/tools/testing/selftests/bpf/progs/sock_iter_batch.c b/tools/testing/selftests/bpf/progs/sock_iter_batch.c index 96531b0d9d55b..8f483337e103c 100644 --- a/tools/testing/selftests/bpf/progs/sock_iter_batch.c +++ b/tools/testing/selftests/bpf/progs/sock_iter_batch.c @@ -17,6 +17,12 @@ static bool ipv6_addr_loopback(const struct in6_addr *a) a->s6_addr32[2] | (a->s6_addr32[3] ^ bpf_htonl(1))) == 0; } +static bool ipv4_addr_loopback(__be32 a) +{ + return a == bpf_ntohl(0x7f000001); +} + +volatile const unsigned int sf; volatile const __u16 ports[2]; unsigned int bucket[2]; @@ -26,16 +32,20 @@ int iter_tcp_soreuse(struct bpf_iter__tcp *ctx) struct sock *sk = (struct sock *)ctx->sk_common; struct inet_hashinfo *hinfo; unsigned int hash; + __u64 sock_cookie; struct net *net; int idx; if (!sk) return 0; + sock_cookie = bpf_get_socket_cookie(sk); sk = bpf_core_cast(sk, struct sock); - if (sk->sk_family != AF_INET6 || + if (sk->sk_family != sf || sk->sk_state != TCP_LISTEN || - !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr)) + sk->sk_family == AF_INET6 ? + !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr) : + !ipv4_addr_loopback(sk->sk_rcv_saddr)) return 0; if (sk->sk_num == ports[0]) @@ -52,6 +62,7 @@ int iter_tcp_soreuse(struct bpf_iter__tcp *ctx) hinfo = net->ipv4.tcp_death_row.hashinfo; bucket[idx] = hash & hinfo->lhash2_mask; bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx)); + bpf_seq_write(ctx->meta->seq, &sock_cookie, sizeof(sock_cookie)); return 0; } @@ -63,14 +74,18 @@ int iter_udp_soreuse(struct bpf_iter__udp *ctx) { struct sock *sk = (struct sock *)ctx->udp_sk; struct udp_table *udptable; + __u64 sock_cookie; int idx; if (!sk) return 0; + sock_cookie = bpf_get_socket_cookie(sk); sk = bpf_core_cast(sk, struct sock); - if (sk->sk_family != AF_INET6 || - !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr)) + if (sk->sk_family != sf || + sk->sk_family == AF_INET6 ? + !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr) : + !ipv4_addr_loopback(sk->sk_rcv_saddr)) return 0; if (sk->sk_num == ports[0]) @@ -84,6 +99,7 @@ int iter_udp_soreuse(struct bpf_iter__udp *ctx) udptable = sk->sk_net.net->ipv4.udp_table; bucket[idx] = udp_sk(sk)->udp_portaddr_hash & udptable->mask; bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx)); + bpf_seq_write(ctx->meta->seq, &sock_cookie, sizeof(sock_cookie)); return 0; }