
[v4,bpf-next,1/4] bpf: Implement batching in UDP iterator

Message ID 20230323200633.3175753-2-aditi.ghag@isovalent.com (mailing list archive)
State Changes Requested
Delegated to: BPF
Series bpf-next: Add socket destroy capability

Checks

Context Check Description
bpf/vmtest-bpf-next-PR success PR summary
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-2 success Logs for build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-3 success Logs for build for aarch64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-4 success Logs for build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-5 success Logs for build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-6 success Logs for build for x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-7 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-8 success Logs for test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for test_maps on aarch64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-10 success Logs for test_maps on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-11 success Logs for test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-12 success Logs for test_maps on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-13 success Logs for test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for test_progs on aarch64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-15 success Logs for test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-16 success Logs for test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for test_progs on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-18 success Logs for test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-19 success Logs for test_progs_no_alu32 on aarch64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-20 success Logs for test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-21 success Logs for test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for test_progs_no_alu32 on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-23 success Logs for test_progs_no_alu32_parallel on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-24 success Logs for test_progs_no_alu32_parallel on aarch64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-25 success Logs for test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-26 success Logs for test_progs_no_alu32_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-27 success Logs for test_progs_parallel on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-28 success Logs for test_progs_parallel on aarch64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-29 success Logs for test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-30 success Logs for test_progs_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-31 success Logs for test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-32 success Logs for test_verifier on aarch64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-33 success Logs for test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-34 success Logs for test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-35 success Logs for test_verifier on x86_64 with llvm-16

Commit Message

Aditi Ghag March 23, 2023, 8:06 p.m. UTC
Batch UDP sockets from the BPF iterator so that BPF/kernel helpers executed
in BPF programs can use overlapping locking semantics. This allows the BPF
socket destroy kfunc (introduced by follow-up patches) to execute from BPF
iterator programs.

Previously, BPF iterators acquired the sock lock and the socket hash table
bucket lock while executing BPF programs. This prevented BPF helpers that
acquire these locks again from being executed from BPF iterators. With the
batching approach, we acquire a bucket lock, batch all the bucket's sockets,
and then release the bucket lock. This enables BPF or kernel helpers to
skip sock locking when invoked in the supported BPF contexts.
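
In outline, the batching step looks like the sketch below (simplified; the
helper name is made up for illustration, and the resize/error handling of the
actual patch is left out):

	static struct sock *udp_iter_batch_one_bucket(struct bpf_udp_iter_state *iter,
						      struct seq_file *seq,
						      struct udp_hslot *hslot)
	{
		struct sock *sk;

		iter->cur_sk = iter->end_sk = 0;

		spin_lock_bh(&hslot->lock);
		sk_for_each(sk, &hslot->head) {
			if (!seq_sk_match(seq, sk))
				continue;
			if (iter->end_sk < iter->max_sk) {
				sock_hold(sk);	/* keep sk alive after unlock */
				iter->batch[iter->end_sk++] = sk;
			}
		}
		spin_unlock_bh(&hslot->lock);

		/* seq_show() can now lock or destroy each batched socket
		 * without holding hslot->lock.
		 */
		return iter->end_sk ? iter->batch[0] : NULL;
	}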

The batching logic is similar to the logic implemented in TCP iterator:
https://lore.kernel.org/bpf/20210701200613.1036157-1-kafai@fb.com/.

Suggested-by: Martin KaFai Lau <martin.lau@kernel.org>
Signed-off-by: Aditi Ghag <aditi.ghag@isovalent.com>
---
 include/net/udp.h |   1 +
 net/ipv4/udp.c    | 255 ++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 247 insertions(+), 9 deletions(-)

Comments

Stanislav Fomichev March 24, 2023, 9:56 p.m. UTC | #1
On 03/23, Aditi Ghag wrote:
> Batch UDP sockets from BPF iterator that allows for overlapping locking
> semantics in BPF/kernel helpers executed in BPF programs.  This  
> facilitates
> BPF socket destroy kfunc (introduced by follow-up patches) to execute from
> BPF iterator programs.

> Previously, BPF iterators acquired the sock lock and sockets hash table
> bucket lock while executing BPF programs. This prevented BPF helpers that
> again acquire these locks to be executed from BPF iterators.  With the
> batching approach, we acquire a bucket lock, batch all the bucket sockets,
> and then release the bucket lock. This enables BPF or kernel helpers to
> skip sock locking when invoked in the supported BPF contexts.

> The batching logic is similar to the logic implemented in TCP iterator:
> https://lore.kernel.org/bpf/20210701200613.1036157-1-kafai@fb.com/.

> Suggested-by: Martin KaFai Lau <martin.lau@kernel.org>
> Signed-off-by: Aditi Ghag <aditi.ghag@isovalent.com>
> ---
>   include/net/udp.h |   1 +
>   net/ipv4/udp.c    | 255 ++++++++++++++++++++++++++++++++++++++++++++--
>   2 files changed, 247 insertions(+), 9 deletions(-)

> diff --git a/include/net/udp.h b/include/net/udp.h
> index de4b528522bb..d2999447d3f2 100644
> --- a/include/net/udp.h
> +++ b/include/net/udp.h
> @@ -437,6 +437,7 @@ struct udp_seq_afinfo {
>   struct udp_iter_state {
>   	struct seq_net_private  p;
>   	int			bucket;
> +	int			offset;
>   	struct udp_seq_afinfo	*bpf_seq_afinfo;
>   };

> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> index c605d171eb2d..58c620243e47 100644
> --- a/net/ipv4/udp.c
> +++ b/net/ipv4/udp.c
> @@ -3152,6 +3152,171 @@ struct bpf_iter__udp {
>   	int bucket __aligned(8);
>   };

> +struct bpf_udp_iter_state {
> +	struct udp_iter_state state;
> +	unsigned int cur_sk;
> +	unsigned int end_sk;
> +	unsigned int max_sk;
> +	struct sock **batch;
> +	bool st_bucket_done;
> +};
> +
> +static unsigned short seq_file_family(const struct seq_file *seq);
> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
> +				      unsigned int new_batch_sz);
> +
> +static inline bool seq_sk_match(struct seq_file *seq, const struct sock  
> *sk)
> +{
> +	unsigned short family = seq_file_family(seq);
> +
> +	/* AF_UNSPEC is used as a match all */
> +	return ((family == AF_UNSPEC || family == sk->sk_family) &&
> +		net_eq(sock_net(sk), seq_file_net(seq)));
> +}
> +
> +static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
> +{
> +	struct bpf_udp_iter_state *iter = seq->private;
> +	struct udp_iter_state *state = &iter->state;
> +	struct net *net = seq_file_net(seq);
> +	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
> +	struct udp_table *udptable;
> +	struct sock *first_sk = NULL;
> +	struct sock *sk;
> +	unsigned int bucket_sks = 0;
> +	bool resized = false;
> +	int offset = 0;
> +	int new_offset;
> +
> +	/* The current batch is done, so advance the bucket. */
> +	if (iter->st_bucket_done) {
> +		state->bucket++;
> +		state->offset = 0;
> +	}
> +
> +	udptable = udp_get_table_afinfo(afinfo, net);
> +
> +	if (state->bucket > udptable->mask) {
> +		state->bucket = 0;
> +		state->offset = 0;
> +		return NULL;
> +	}
> +
> +again:
> +	/* New batch for the next bucket.
> +	 * Iterate over the hash table to find a bucket with sockets matching
> +	 * the iterator attributes, and return the first matching socket from
> +	 * the bucket. The remaining matched sockets from the bucket are batched
> +	 * before releasing the bucket lock. This allows BPF programs that are
> +	 * called in seq_show to acquire the bucket lock if needed.
> +	 */
> +	iter->cur_sk = 0;
> +	iter->end_sk = 0;
> +	iter->st_bucket_done = false;
> +	first_sk = NULL;
> +	bucket_sks = 0;
> +	offset = state->offset;
> +	new_offset = offset;
> +
> +	for (; state->bucket <= udptable->mask; state->bucket++) {
> +		struct udp_hslot *hslot = &udptable->hash[state->bucket];
> +
> +		if (hlist_empty(&hslot->head)) {
> +			offset = 0;
> +			continue;
> +		}
> +
> +		spin_lock_bh(&hslot->lock);
> +		/* Resume from the last saved position in a bucket before
> +		 * iterator was stopped.
> +		 */
> +		while (offset-- > 0) {
> +			sk_for_each(sk, &hslot->head)
> +				continue;
> +		}
> +		sk_for_each(sk, &hslot->head) {
> +			if (seq_sk_match(seq, sk)) {
> +				if (!first_sk)
> +					first_sk = sk;
> +				if (iter->end_sk < iter->max_sk) {
> +					sock_hold(sk);
> +					iter->batch[iter->end_sk++] = sk;
> +				}
> +				bucket_sks++;
> +			}
> +			new_offset++;
> +		}
> +		spin_unlock_bh(&hslot->lock);
> +
> +		if (first_sk)
> +			break;
> +
> +		/* Reset the current bucket's offset before moving to the next bucket.  
> */
> +		offset = 0;
> +		new_offset = 0;
> +	}
> +
> +	/* All done: no batch made. */
> +	if (!first_sk)
> +		goto ret;
> +
> +	if (iter->end_sk == bucket_sks) {
> +		/* Batching is done for the current bucket; return the first
> +		 * socket to be iterated from the batch.
> +		 */
> +		iter->st_bucket_done = true;
> +		goto ret;
> +	}
> +	if (!resized && !bpf_iter_udp_realloc_batch(iter, bucket_sks * 3 / 2)) {
> +		resized = true;
> +		/* Go back to the previous bucket to resize its batch. */
> +		state->bucket--;
> +		goto again;
> +	}
> +ret:
> +	state->offset = new_offset;
> +	return first_sk;
> +}
> +
> +static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t  
> *pos)
> +{
> +	struct bpf_udp_iter_state *iter = seq->private;
> +	struct udp_iter_state *state = &iter->state;
> +	struct sock *sk;
> +
> +	/* Whenever seq_next() is called, the iter->cur_sk is
> +	 * done with seq_show(), so unref the iter->cur_sk.
> +	 */
> +	if (iter->cur_sk < iter->end_sk) {
> +		sock_put(iter->batch[iter->cur_sk++]);
> +		++state->offset;
> +	}
> +
> +	/* After updating iter->cur_sk, check if there are more sockets
> +	 * available in the current bucket batch.
> +	 */
> +	if (iter->cur_sk < iter->end_sk) {
> +		sk = iter->batch[iter->cur_sk];
> +	} else {
> +		// Prepare a new batch.
> +		sk = bpf_iter_udp_batch(seq);
> +	}
> +
> +	++*pos;
> +	return sk;
> +}
> +
> +static void *bpf_iter_udp_seq_start(struct seq_file *seq, loff_t *pos)
> +{
> +	/* bpf iter does not support lseek, so it always
> +	 * continue from where it was stop()-ped.
> +	 */
> +	if (*pos)
> +		return bpf_iter_udp_batch(seq);
> +
> +	return SEQ_START_TOKEN;
> +}
> +
>   static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta  
> *meta,
>   			     struct udp_sock *udp_sk, uid_t uid, int bucket)
>   {
> @@ -3172,18 +3337,38 @@ static int bpf_iter_udp_seq_show(struct seq_file  
> *seq, void *v)
>   	struct bpf_prog *prog;
>   	struct sock *sk = v;
>   	uid_t uid;
> +	bool slow;
> +	int rc;

>   	if (v == SEQ_START_TOKEN)
>   		return 0;


[..]

> +	slow = lock_sock_fast(sk);
> +
> +	if (unlikely(sk_unhashed(sk))) {
> +		rc = SEQ_SKIP;
> +		goto unlock;
> +	}
> +

Should we use non-fast version here for consistency with tcp?


>   	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
>   	meta.seq = seq;
>   	prog = bpf_iter_get_info(&meta, false);
> -	return udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
> +	rc = udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
> +
> +unlock:
> +	unlock_sock_fast(sk, slow);
> +	return rc;
> +}
> +
> +static void bpf_iter_udp_unref_batch(struct bpf_udp_iter_state *iter)
> +{
> +	while (iter->cur_sk < iter->end_sk)
> +		sock_put(iter->batch[iter->cur_sk++]);
>   }

>   static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
>   {
> +	struct bpf_udp_iter_state *iter = seq->private;
>   	struct bpf_iter_meta meta;
>   	struct bpf_prog *prog;

> @@ -3194,15 +3379,31 @@ static void bpf_iter_udp_seq_stop(struct seq_file  
> *seq, void *v)
>   			(void)udp_prog_seq_show(prog, &meta, v, 0, 0);
>   	}

> -	udp_seq_stop(seq, v);
> +	if (iter->cur_sk < iter->end_sk) {
> +		bpf_iter_udp_unref_batch(iter);
> +		iter->st_bucket_done = false;
> +	}
>   }

>   static const struct seq_operations bpf_iter_udp_seq_ops = {
> -	.start		= udp_seq_start,
> -	.next		= udp_seq_next,
> +	.start		= bpf_iter_udp_seq_start,
> +	.next		= bpf_iter_udp_seq_next,
>   	.stop		= bpf_iter_udp_seq_stop,
>   	.show		= bpf_iter_udp_seq_show,
>   };
> +
> +static unsigned short seq_file_family(const struct seq_file *seq)
> +{
> +	const struct udp_seq_afinfo *afinfo;
> +
> +	/* BPF iterator: bpf programs to filter sockets. */
> +	if (seq->op == &bpf_iter_udp_seq_ops)
> +		return AF_UNSPEC;
> +
> +	/* Proc fs iterator */
> +	afinfo = pde_data(file_inode(seq->file));
> +	return afinfo->family;
> +}
>   #endif

>   const struct seq_operations udp_seq_ops = {
> @@ -3413,9 +3614,30 @@ static struct pernet_operations __net_initdata  
> udp_sysctl_ops = {
>   DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta,
>   		     struct udp_sock *udp_sk, uid_t uid, int bucket)

> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
> +				      unsigned int new_batch_sz)
> +{
> +	struct sock **new_batch;
> +
> +	new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch),
> +				   GFP_USER | __GFP_NOWARN);
> +	if (!new_batch)
> +		return -ENOMEM;
> +
> +	bpf_iter_udp_unref_batch(iter);
> +	kvfree(iter->batch);
> +	iter->batch = new_batch;
> +	iter->max_sk = new_batch_sz;
> +
> +	return 0;
> +}
> +
> +#define INIT_BATCH_SZ 16
> +
>   static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info  
> *aux)
>   {
> -	struct udp_iter_state *st = priv_data;
> +	struct bpf_udp_iter_state *iter = priv_data;
> +	struct udp_iter_state *st = &iter->state;
>   	struct udp_seq_afinfo *afinfo;
>   	int ret;

> @@ -3427,24 +3649,39 @@ static int bpf_iter_init_udp(void *priv_data,  
> struct bpf_iter_aux_info *aux)
>   	afinfo->udp_table = NULL;
>   	st->bpf_seq_afinfo = afinfo;
>   	ret = bpf_iter_init_seq_net(priv_data, aux);
> -	if (ret)
> +	if (ret) {
>   		kfree(afinfo);
> +		return ret;
> +	}
> +	ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ);
> +	if (ret) {
> +		bpf_iter_fini_seq_net(priv_data);
> +		return ret;
> +	}
> +	iter->cur_sk = 0;
> +	iter->end_sk = 0;
> +	iter->st_bucket_done = false;
> +	st->bucket = 0;
> +	st->offset = 0;
> +
>   	return ret;
>   }

>   static void bpf_iter_fini_udp(void *priv_data)
>   {
> -	struct udp_iter_state *st = priv_data;
> +	struct bpf_udp_iter_state *iter = priv_data;
> +	struct udp_iter_state *st = &iter->state;

> -	kfree(st->bpf_seq_afinfo);
>   	bpf_iter_fini_seq_net(priv_data);
> +	kfree(st->bpf_seq_afinfo);
> +	kvfree(iter->batch);
>   }

>   static const struct bpf_iter_seq_info udp_seq_info = {
>   	.seq_ops		= &bpf_iter_udp_seq_ops,
>   	.init_seq_private	= bpf_iter_init_udp,
>   	.fini_seq_private	= bpf_iter_fini_udp,
> -	.seq_priv_size		= sizeof(struct udp_iter_state),
> +	.seq_priv_size		= sizeof(struct bpf_udp_iter_state),
>   };

>   static struct bpf_iter_reg udp_reg_info = {
> --
> 2.34.1
Aditi Ghag March 27, 2023, 3:52 p.m. UTC | #2
> On Mar 24, 2023, at 2:56 PM, Stanislav Fomichev <sdf@google.com> wrote:
> 
> On 03/23, Aditi Ghag wrote:
>> Batch UDP sockets from BPF iterator that allows for overlapping locking
>> semantics in BPF/kernel helpers executed in BPF programs.  This facilitates
>> BPF socket destroy kfunc (introduced by follow-up patches) to execute from
>> BPF iterator programs.
> 
>> Previously, BPF iterators acquired the sock lock and sockets hash table
>> bucket lock while executing BPF programs. This prevented BPF helpers that
>> again acquire these locks to be executed from BPF iterators.  With the
>> batching approach, we acquire a bucket lock, batch all the bucket sockets,
>> and then release the bucket lock. This enables BPF or kernel helpers to
>> skip sock locking when invoked in the supported BPF contexts.
> 
>> The batching logic is similar to the logic implemented in TCP iterator:
>> https://lore.kernel.org/bpf/20210701200613.1036157-1-kafai@fb.com/.
> 
>> Suggested-by: Martin KaFai Lau <martin.lau@kernel.org>
>> Signed-off-by: Aditi Ghag <aditi.ghag@isovalent.com>
>> ---
>>  include/net/udp.h |   1 +
>>  net/ipv4/udp.c    | 255 ++++++++++++++++++++++++++++++++++++++++++++--
>>  2 files changed, 247 insertions(+), 9 deletions(-)
> 
>> diff --git a/include/net/udp.h b/include/net/udp.h
>> index de4b528522bb..d2999447d3f2 100644
>> --- a/include/net/udp.h
>> +++ b/include/net/udp.h
>> @@ -437,6 +437,7 @@ struct udp_seq_afinfo {
>>  struct udp_iter_state {
>>  	struct seq_net_private  p;
>>  	int			bucket;
>> +	int			offset;
>>  	struct udp_seq_afinfo	*bpf_seq_afinfo;
>>  };
> 
>> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
>> index c605d171eb2d..58c620243e47 100644
>> --- a/net/ipv4/udp.c
>> +++ b/net/ipv4/udp.c
>> @@ -3152,6 +3152,171 @@ struct bpf_iter__udp {
>>  	int bucket __aligned(8);
>>  };
> 
>> +struct bpf_udp_iter_state {
>> +	struct udp_iter_state state;
>> +	unsigned int cur_sk;
>> +	unsigned int end_sk;
>> +	unsigned int max_sk;
>> +	struct sock **batch;
>> +	bool st_bucket_done;
>> +};
>> +
>> +static unsigned short seq_file_family(const struct seq_file *seq);
>> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
>> +				      unsigned int new_batch_sz);
>> +
>> +static inline bool seq_sk_match(struct seq_file *seq, const struct sock *sk)
>> +{
>> +	unsigned short family = seq_file_family(seq);
>> +
>> +	/* AF_UNSPEC is used as a match all */
>> +	return ((family == AF_UNSPEC || family == sk->sk_family) &&
>> +		net_eq(sock_net(sk), seq_file_net(seq)));
>> +}
>> +
>> +static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
>> +{
>> +	struct bpf_udp_iter_state *iter = seq->private;
>> +	struct udp_iter_state *state = &iter->state;
>> +	struct net *net = seq_file_net(seq);
>> +	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
>> +	struct udp_table *udptable;
>> +	struct sock *first_sk = NULL;
>> +	struct sock *sk;
>> +	unsigned int bucket_sks = 0;
>> +	bool resized = false;
>> +	int offset = 0;
>> +	int new_offset;
>> +
>> +	/* The current batch is done, so advance the bucket. */
>> +	if (iter->st_bucket_done) {
>> +		state->bucket++;
>> +		state->offset = 0;
>> +	}
>> +
>> +	udptable = udp_get_table_afinfo(afinfo, net);
>> +
>> +	if (state->bucket > udptable->mask) {
>> +		state->bucket = 0;
>> +		state->offset = 0;
>> +		return NULL;
>> +	}
>> +
>> +again:
>> +	/* New batch for the next bucket.
>> +	 * Iterate over the hash table to find a bucket with sockets matching
>> +	 * the iterator attributes, and return the first matching socket from
>> +	 * the bucket. The remaining matched sockets from the bucket are batched
>> +	 * before releasing the bucket lock. This allows BPF programs that are
>> +	 * called in seq_show to acquire the bucket lock if needed.
>> +	 */
>> +	iter->cur_sk = 0;
>> +	iter->end_sk = 0;
>> +	iter->st_bucket_done = false;
>> +	first_sk = NULL;
>> +	bucket_sks = 0;
>> +	offset = state->offset;
>> +	new_offset = offset;
>> +
>> +	for (; state->bucket <= udptable->mask; state->bucket++) {
>> +		struct udp_hslot *hslot = &udptable->hash[state->bucket];
>> +
>> +		if (hlist_empty(&hslot->head)) {
>> +			offset = 0;
>> +			continue;
>> +		}
>> +
>> +		spin_lock_bh(&hslot->lock);
>> +		/* Resume from the last saved position in a bucket before
>> +		 * iterator was stopped.
>> +		 */
>> +		while (offset-- > 0) {
>> +			sk_for_each(sk, &hslot->head)
>> +				continue;
>> +		}
>> +		sk_for_each(sk, &hslot->head) {
>> +			if (seq_sk_match(seq, sk)) {
>> +				if (!first_sk)
>> +					first_sk = sk;
>> +				if (iter->end_sk < iter->max_sk) {
>> +					sock_hold(sk);
>> +					iter->batch[iter->end_sk++] = sk;
>> +				}
>> +				bucket_sks++;
>> +			}
>> +			new_offset++;
>> +		}
>> +		spin_unlock_bh(&hslot->lock);
>> +
>> +		if (first_sk)
>> +			break;
>> +
>> +		/* Reset the current bucket's offset before moving to the next bucket. */
>> +		offset = 0;
>> +		new_offset = 0;
>> +	}
>> +
>> +	/* All done: no batch made. */
>> +	if (!first_sk)
>> +		goto ret;
>> +
>> +	if (iter->end_sk == bucket_sks) {
>> +		/* Batching is done for the current bucket; return the first
>> +		 * socket to be iterated from the batch.
>> +		 */
>> +		iter->st_bucket_done = true;
>> +		goto ret;
>> +	}
>> +	if (!resized && !bpf_iter_udp_realloc_batch(iter, bucket_sks * 3 / 2)) {
>> +		resized = true;
>> +		/* Go back to the previous bucket to resize its batch. */
>> +		state->bucket--;
>> +		goto again;
>> +	}
>> +ret:
>> +	state->offset = new_offset;
>> +	return first_sk;
>> +}
>> +
>> +static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
>> +{
>> +	struct bpf_udp_iter_state *iter = seq->private;
>> +	struct udp_iter_state *state = &iter->state;
>> +	struct sock *sk;
>> +
>> +	/* Whenever seq_next() is called, the iter->cur_sk is
>> +	 * done with seq_show(), so unref the iter->cur_sk.
>> +	 */
>> +	if (iter->cur_sk < iter->end_sk) {
>> +		sock_put(iter->batch[iter->cur_sk++]);
>> +		++state->offset;
>> +	}
>> +
>> +	/* After updating iter->cur_sk, check if there are more sockets
>> +	 * available in the current bucket batch.
>> +	 */
>> +	if (iter->cur_sk < iter->end_sk) {
>> +		sk = iter->batch[iter->cur_sk];
>> +	} else {
>> +		// Prepare a new batch.
>> +		sk = bpf_iter_udp_batch(seq);
>> +	}
>> +
>> +	++*pos;
>> +	return sk;
>> +}
>> +
>> +static void *bpf_iter_udp_seq_start(struct seq_file *seq, loff_t *pos)
>> +{
>> +	/* bpf iter does not support lseek, so it always
>> +	 * continue from where it was stop()-ped.
>> +	 */
>> +	if (*pos)
>> +		return bpf_iter_udp_batch(seq);
>> +
>> +	return SEQ_START_TOKEN;
>> +}
>> +
>>  static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
>>  			     struct udp_sock *udp_sk, uid_t uid, int bucket)
>>  {
>> @@ -3172,18 +3337,38 @@ static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v)
>>  	struct bpf_prog *prog;
>>  	struct sock *sk = v;
>>  	uid_t uid;
>> +	bool slow;
>> +	int rc;
> 
>>  	if (v == SEQ_START_TOKEN)
>>  		return 0;
> 
> 
> [..]
> 
>> +	slow = lock_sock_fast(sk);
>> +
>> +	if (unlikely(sk_unhashed(sk))) {
>> +		rc = SEQ_SKIP;
>> +		goto unlock;
>> +	}
>> +
> 
> Should we use non-fast version here for consistency with tcp?

We could, but I don't see a problem with acquiring the fast version for UDP, so we could just stick with it. The TCP change warrants a code comment though; I'll add it in the next revision.

> 
> 
>>  	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
>>  	meta.seq = seq;
>>  	prog = bpf_iter_get_info(&meta, false);
>> -	return udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
>> +	rc = udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
>> +
>> +unlock:
>> +	unlock_sock_fast(sk, slow);
>> +	return rc;
>> +}
>> +
>> +static void bpf_iter_udp_unref_batch(struct bpf_udp_iter_state *iter)
>> +{
>> +	while (iter->cur_sk < iter->end_sk)
>> +		sock_put(iter->batch[iter->cur_sk++]);
>>  }
> 
>>  static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
>>  {
>> +	struct bpf_udp_iter_state *iter = seq->private;
>>  	struct bpf_iter_meta meta;
>>  	struct bpf_prog *prog;
> 
>> @@ -3194,15 +3379,31 @@ static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
>>  			(void)udp_prog_seq_show(prog, &meta, v, 0, 0);
>>  	}
> 
>> -	udp_seq_stop(seq, v);
>> +	if (iter->cur_sk < iter->end_sk) {
>> +		bpf_iter_udp_unref_batch(iter);
>> +		iter->st_bucket_done = false;
>> +	}
>>  }
> 
>>  static const struct seq_operations bpf_iter_udp_seq_ops = {
>> -	.start		= udp_seq_start,
>> -	.next		= udp_seq_next,
>> +	.start		= bpf_iter_udp_seq_start,
>> +	.next		= bpf_iter_udp_seq_next,
>>  	.stop		= bpf_iter_udp_seq_stop,
>>  	.show		= bpf_iter_udp_seq_show,
>>  };
>> +
>> +static unsigned short seq_file_family(const struct seq_file *seq)
>> +{
>> +	const struct udp_seq_afinfo *afinfo;
>> +
>> +	/* BPF iterator: bpf programs to filter sockets. */
>> +	if (seq->op == &bpf_iter_udp_seq_ops)
>> +		return AF_UNSPEC;
>> +
>> +	/* Proc fs iterator */
>> +	afinfo = pde_data(file_inode(seq->file));
>> +	return afinfo->family;
>> +}
>>  #endif
> 
>>  const struct seq_operations udp_seq_ops = {
>> @@ -3413,9 +3614,30 @@ static struct pernet_operations __net_initdata udp_sysctl_ops = {
>>  DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta,
>>  		     struct udp_sock *udp_sk, uid_t uid, int bucket)
> 
>> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
>> +				      unsigned int new_batch_sz)
>> +{
>> +	struct sock **new_batch;
>> +
>> +	new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch),
>> +				   GFP_USER | __GFP_NOWARN);
>> +	if (!new_batch)
>> +		return -ENOMEM;
>> +
>> +	bpf_iter_udp_unref_batch(iter);
>> +	kvfree(iter->batch);
>> +	iter->batch = new_batch;
>> +	iter->max_sk = new_batch_sz;
>> +
>> +	return 0;
>> +}
>> +
>> +#define INIT_BATCH_SZ 16
>> +
>>  static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
>>  {
>> -	struct udp_iter_state *st = priv_data;
>> +	struct bpf_udp_iter_state *iter = priv_data;
>> +	struct udp_iter_state *st = &iter->state;
>>  	struct udp_seq_afinfo *afinfo;
>>  	int ret;
> 
>> @@ -3427,24 +3649,39 @@ static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
>>  	afinfo->udp_table = NULL;
>>  	st->bpf_seq_afinfo = afinfo;
>>  	ret = bpf_iter_init_seq_net(priv_data, aux);
>> -	if (ret)
>> +	if (ret) {
>>  		kfree(afinfo);
>> +		return ret;
>> +	}
>> +	ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ);
>> +	if (ret) {
>> +		bpf_iter_fini_seq_net(priv_data);
>> +		return ret;
>> +	}
>> +	iter->cur_sk = 0;
>> +	iter->end_sk = 0;
>> +	iter->st_bucket_done = false;
>> +	st->bucket = 0;
>> +	st->offset = 0;
>> +
>>  	return ret;
>>  }
> 
>>  static void bpf_iter_fini_udp(void *priv_data)
>>  {
>> -	struct udp_iter_state *st = priv_data;
>> +	struct bpf_udp_iter_state *iter = priv_data;
>> +	struct udp_iter_state *st = &iter->state;
> 
>> -	kfree(st->bpf_seq_afinfo);
>>  	bpf_iter_fini_seq_net(priv_data);
>> +	kfree(st->bpf_seq_afinfo);
>> +	kvfree(iter->batch);
>>  }
> 
>>  static const struct bpf_iter_seq_info udp_seq_info = {
>>  	.seq_ops		= &bpf_iter_udp_seq_ops,
>>  	.init_seq_private	= bpf_iter_init_udp,
>>  	.fini_seq_private	= bpf_iter_fini_udp,
>> -	.seq_priv_size		= sizeof(struct udp_iter_state),
>> +	.seq_priv_size		= sizeof(struct bpf_udp_iter_state),
>>  };
> 
>>  static struct bpf_iter_reg udp_reg_info = {
>> --
>> 2.34.1
Stanislav Fomichev March 27, 2023, 4:52 p.m. UTC | #3
On 03/27, Aditi Ghag wrote:


> > On Mar 24, 2023, at 2:56 PM, Stanislav Fomichev <sdf@google.com> wrote:
> >
> > On 03/23, Aditi Ghag wrote:
> >> Batch UDP sockets from BPF iterator that allows for overlapping locking
> >> semantics in BPF/kernel helpers executed in BPF programs.  This  
> facilitates
> >> BPF socket destroy kfunc (introduced by follow-up patches) to execute  
> from
> >> BPF iterator programs.
> >
> >> Previously, BPF iterators acquired the sock lock and sockets hash table
> >> bucket lock while executing BPF programs. This prevented BPF helpers  
> that
> >> again acquire these locks to be executed from BPF iterators.  With the
> >> batching approach, we acquire a bucket lock, batch all the bucket  
> sockets,
> >> and then release the bucket lock. This enables BPF or kernel helpers to
> >> skip sock locking when invoked in the supported BPF contexts.
> >
> >> The batching logic is similar to the logic implemented in TCP iterator:
> >> https://lore.kernel.org/bpf/20210701200613.1036157-1-kafai@fb.com/.
> >
> >> Suggested-by: Martin KaFai Lau <martin.lau@kernel.org>
> >> Signed-off-by: Aditi Ghag <aditi.ghag@isovalent.com>
> >> ---
> >>  include/net/udp.h |   1 +
> >>  net/ipv4/udp.c    | 255 ++++++++++++++++++++++++++++++++++++++++++++--
> >>  2 files changed, 247 insertions(+), 9 deletions(-)
> >
> >> diff --git a/include/net/udp.h b/include/net/udp.h
> >> index de4b528522bb..d2999447d3f2 100644
> >> --- a/include/net/udp.h
> >> +++ b/include/net/udp.h
> >> @@ -437,6 +437,7 @@ struct udp_seq_afinfo {
> >>  struct udp_iter_state {
> >>  	struct seq_net_private  p;
> >>  	int			bucket;
> >> +	int			offset;
> >>  	struct udp_seq_afinfo	*bpf_seq_afinfo;
> >>  };
> >
> >> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> >> index c605d171eb2d..58c620243e47 100644
> >> --- a/net/ipv4/udp.c
> >> +++ b/net/ipv4/udp.c
> >> @@ -3152,6 +3152,171 @@ struct bpf_iter__udp {
> >>  	int bucket __aligned(8);
> >>  };
> >
> >> +struct bpf_udp_iter_state {
> >> +	struct udp_iter_state state;
> >> +	unsigned int cur_sk;
> >> +	unsigned int end_sk;
> >> +	unsigned int max_sk;
> >> +	struct sock **batch;
> >> +	bool st_bucket_done;
> >> +};
> >> +
> >> +static unsigned short seq_file_family(const struct seq_file *seq);
> >> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
> >> +				      unsigned int new_batch_sz);
> >> +
> >> +static inline bool seq_sk_match(struct seq_file *seq, const struct  
> sock *sk)
> >> +{
> >> +	unsigned short family = seq_file_family(seq);
> >> +
> >> +	/* AF_UNSPEC is used as a match all */
> >> +	return ((family == AF_UNSPEC || family == sk->sk_family) &&
> >> +		net_eq(sock_net(sk), seq_file_net(seq)));
> >> +}
> >> +
> >> +static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
> >> +{
> >> +	struct bpf_udp_iter_state *iter = seq->private;
> >> +	struct udp_iter_state *state = &iter->state;
> >> +	struct net *net = seq_file_net(seq);
> >> +	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
> >> +	struct udp_table *udptable;
> >> +	struct sock *first_sk = NULL;
> >> +	struct sock *sk;
> >> +	unsigned int bucket_sks = 0;
> >> +	bool resized = false;
> >> +	int offset = 0;
> >> +	int new_offset;
> >> +
> >> +	/* The current batch is done, so advance the bucket. */
> >> +	if (iter->st_bucket_done) {
> >> +		state->bucket++;
> >> +		state->offset = 0;
> >> +	}
> >> +
> >> +	udptable = udp_get_table_afinfo(afinfo, net);
> >> +
> >> +	if (state->bucket > udptable->mask) {
> >> +		state->bucket = 0;
> >> +		state->offset = 0;
> >> +		return NULL;
> >> +	}
> >> +
> >> +again:
> >> +	/* New batch for the next bucket.
> >> +	 * Iterate over the hash table to find a bucket with sockets matching
> >> +	 * the iterator attributes, and return the first matching socket from
> >> +	 * the bucket. The remaining matched sockets from the bucket are  
> batched
> >> +	 * before releasing the bucket lock. This allows BPF programs that  
> are
> >> +	 * called in seq_show to acquire the bucket lock if needed.
> >> +	 */
> >> +	iter->cur_sk = 0;
> >> +	iter->end_sk = 0;
> >> +	iter->st_bucket_done = false;
> >> +	first_sk = NULL;
> >> +	bucket_sks = 0;
> >> +	offset = state->offset;
> >> +	new_offset = offset;
> >> +
> >> +	for (; state->bucket <= udptable->mask; state->bucket++) {
> >> +		struct udp_hslot *hslot = &udptable->hash[state->bucket];
> >> +
> >> +		if (hlist_empty(&hslot->head)) {
> >> +			offset = 0;
> >> +			continue;
> >> +		}
> >> +
> >> +		spin_lock_bh(&hslot->lock);
> >> +		/* Resume from the last saved position in a bucket before
> >> +		 * iterator was stopped.
> >> +		 */
> >> +		while (offset-- > 0) {
> >> +			sk_for_each(sk, &hslot->head)
> >> +				continue;
> >> +		}
> >> +		sk_for_each(sk, &hslot->head) {
> >> +			if (seq_sk_match(seq, sk)) {
> >> +				if (!first_sk)
> >> +					first_sk = sk;
> >> +				if (iter->end_sk < iter->max_sk) {
> >> +					sock_hold(sk);
> >> +					iter->batch[iter->end_sk++] = sk;
> >> +				}
> >> +				bucket_sks++;
> >> +			}
> >> +			new_offset++;
> >> +		}
> >> +		spin_unlock_bh(&hslot->lock);
> >> +
> >> +		if (first_sk)
> >> +			break;
> >> +
> >> +		/* Reset the current bucket's offset before moving to the next  
> bucket. */
> >> +		offset = 0;
> >> +		new_offset = 0;
> >> +	}
> >> +
> >> +	/* All done: no batch made. */
> >> +	if (!first_sk)
> >> +		goto ret;
> >> +
> >> +	if (iter->end_sk == bucket_sks) {
> >> +		/* Batching is done for the current bucket; return the first
> >> +		 * socket to be iterated from the batch.
> >> +		 */
> >> +		iter->st_bucket_done = true;
> >> +		goto ret;
> >> +	}
> >> +	if (!resized && !bpf_iter_udp_realloc_batch(iter, bucket_sks * 3 /  
> 2)) {
> >> +		resized = true;
> >> +		/* Go back to the previous bucket to resize its batch. */
> >> +		state->bucket--;
> >> +		goto again;
> >> +	}
> >> +ret:
> >> +	state->offset = new_offset;
> >> +	return first_sk;
> >> +}
> >> +
> >> +static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v,  
> loff_t *pos)
> >> +{
> >> +	struct bpf_udp_iter_state *iter = seq->private;
> >> +	struct udp_iter_state *state = &iter->state;
> >> +	struct sock *sk;
> >> +
> >> +	/* Whenever seq_next() is called, the iter->cur_sk is
> >> +	 * done with seq_show(), so unref the iter->cur_sk.
> >> +	 */
> >> +	if (iter->cur_sk < iter->end_sk) {
> >> +		sock_put(iter->batch[iter->cur_sk++]);
> >> +		++state->offset;
> >> +	}
> >> +
> >> +	/* After updating iter->cur_sk, check if there are more sockets
> >> +	 * available in the current bucket batch.
> >> +	 */
> >> +	if (iter->cur_sk < iter->end_sk) {
> >> +		sk = iter->batch[iter->cur_sk];
> >> +	} else {
> >> +		// Prepare a new batch.
> >> +		sk = bpf_iter_udp_batch(seq);
> >> +	}
> >> +
> >> +	++*pos;
> >> +	return sk;
> >> +}
> >> +
> >> +static void *bpf_iter_udp_seq_start(struct seq_file *seq, loff_t *pos)
> >> +{
> >> +	/* bpf iter does not support lseek, so it always
> >> +	 * continue from where it was stop()-ped.
> >> +	 */
> >> +	if (*pos)
> >> +		return bpf_iter_udp_batch(seq);
> >> +
> >> +	return SEQ_START_TOKEN;
> >> +}
> >> +
> >>  static int udp_prog_seq_show(struct bpf_prog *prog, struct  
> bpf_iter_meta *meta,
> >>  			     struct udp_sock *udp_sk, uid_t uid, int bucket)
> >>  {
> >> @@ -3172,18 +3337,38 @@ static int bpf_iter_udp_seq_show(struct  
> seq_file *seq, void *v)
> >>  	struct bpf_prog *prog;
> >>  	struct sock *sk = v;
> >>  	uid_t uid;
> >> +	bool slow;
> >> +	int rc;
> >
> >>  	if (v == SEQ_START_TOKEN)
> >>  		return 0;
> >
> >
> > [..]
> >
> >> +	slow = lock_sock_fast(sk);
> >> +
> >> +	if (unlikely(sk_unhashed(sk))) {
> >> +		rc = SEQ_SKIP;
> >> +		goto unlock;
> >> +	}
> >> +
> >
> > Should we use non-fast version here for consistency with tcp?

> We could, but I don't see a problem with acquiring fast version for UDP  
> so we could just stick with it. The TCP change warrants a code comment  
> though, I'll add it in the next reversion.

lock_sock_fast is an exception and we should have a good reason to use
it in a particular place. It blocks bh (rx softirq) and doesn't
consume the backlog on unlock.

$ grep -ri lock_sock_fast . | wc -l
60

$ grep -ri lock_sock . | wc -l
1075 # this includes 60 from the above, but it doesn't matter

So unless you have a good reason to use it (and not a mere "why not"),
let's use the regular lock_sock here?
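
For illustration, the locking in seq_show() would then look roughly like this
(sketch only; the helper name is hypothetical):

	static int bpf_iter_udp_show_one(struct seq_file *seq, struct sock *sk)
	{
		int rc = SEQ_SKIP;

		lock_sock(sk);		/* sleepable owner lock */
		if (!sk_unhashed(sk))
			rc = 0;		/* ... run the BPF program as in the patch ... */
		release_sock(sk);	/* also processes the backlog queued while
					 * the lock was owned
					 */
		return rc;
	}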

> >
> >
> >>  	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
> >>  	meta.seq = seq;
> >>  	prog = bpf_iter_get_info(&meta, false);
> >> -	return udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
> >> +	rc = udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
> >> +
> >> +unlock:
> >> +	unlock_sock_fast(sk, slow);
> >> +	return rc;
> >> +}
> >> +
> >> +static void bpf_iter_udp_unref_batch(struct bpf_udp_iter_state *iter)
> >> +{
> >> +	while (iter->cur_sk < iter->end_sk)
> >> +		sock_put(iter->batch[iter->cur_sk++]);
> >>  }
> >
> >>  static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
> >>  {
> >> +	struct bpf_udp_iter_state *iter = seq->private;
> >>  	struct bpf_iter_meta meta;
> >>  	struct bpf_prog *prog;
> >
> >> @@ -3194,15 +3379,31 @@ static void bpf_iter_udp_seq_stop(struct  
> seq_file *seq, void *v)
> >>  			(void)udp_prog_seq_show(prog, &meta, v, 0, 0);
> >>  	}
> >
> >> -	udp_seq_stop(seq, v);
> >> +	if (iter->cur_sk < iter->end_sk) {
> >> +		bpf_iter_udp_unref_batch(iter);
> >> +		iter->st_bucket_done = false;
> >> +	}
> >>  }
> >
> >>  static const struct seq_operations bpf_iter_udp_seq_ops = {
> >> -	.start		= udp_seq_start,
> >> -	.next		= udp_seq_next,
> >> +	.start		= bpf_iter_udp_seq_start,
> >> +	.next		= bpf_iter_udp_seq_next,
> >>  	.stop		= bpf_iter_udp_seq_stop,
> >>  	.show		= bpf_iter_udp_seq_show,
> >>  };
> >> +
> >> +static unsigned short seq_file_family(const struct seq_file *seq)
> >> +{
> >> +	const struct udp_seq_afinfo *afinfo;
> >> +
> >> +	/* BPF iterator: bpf programs to filter sockets. */
> >> +	if (seq->op == &bpf_iter_udp_seq_ops)
> >> +		return AF_UNSPEC;
> >> +
> >> +	/* Proc fs iterator */
> >> +	afinfo = pde_data(file_inode(seq->file));
> >> +	return afinfo->family;
> >> +}
> >>  #endif
> >
> >>  const struct seq_operations udp_seq_ops = {
> >> @@ -3413,9 +3614,30 @@ static struct pernet_operations __net_initdata  
> udp_sysctl_ops = {
> >>  DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta,
> >>  		     struct udp_sock *udp_sk, uid_t uid, int bucket)
> >
> >> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
> >> +				      unsigned int new_batch_sz)
> >> +{
> >> +	struct sock **new_batch;
> >> +
> >> +	new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch),
> >> +				   GFP_USER | __GFP_NOWARN);
> >> +	if (!new_batch)
> >> +		return -ENOMEM;
> >> +
> >> +	bpf_iter_udp_unref_batch(iter);
> >> +	kvfree(iter->batch);
> >> +	iter->batch = new_batch;
> >> +	iter->max_sk = new_batch_sz;
> >> +
> >> +	return 0;
> >> +}
> >> +
> >> +#define INIT_BATCH_SZ 16
> >> +
> >>  static int bpf_iter_init_udp(void *priv_data, struct  
> bpf_iter_aux_info *aux)
> >>  {
> >> -	struct udp_iter_state *st = priv_data;
> >> +	struct bpf_udp_iter_state *iter = priv_data;
> >> +	struct udp_iter_state *st = &iter->state;
> >>  	struct udp_seq_afinfo *afinfo;
> >>  	int ret;
> >
> >> @@ -3427,24 +3649,39 @@ static int bpf_iter_init_udp(void *priv_data,  
> struct bpf_iter_aux_info *aux)
> >>  	afinfo->udp_table = NULL;
> >>  	st->bpf_seq_afinfo = afinfo;
> >>  	ret = bpf_iter_init_seq_net(priv_data, aux);
> >> -	if (ret)
> >> +	if (ret) {
> >>  		kfree(afinfo);
> >> +		return ret;
> >> +	}
> >> +	ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ);
> >> +	if (ret) {
> >> +		bpf_iter_fini_seq_net(priv_data);
> >> +		return ret;
> >> +	}
> >> +	iter->cur_sk = 0;
> >> +	iter->end_sk = 0;
> >> +	iter->st_bucket_done = false;
> >> +	st->bucket = 0;
> >> +	st->offset = 0;
> >> +
> >>  	return ret;
> >>  }
> >
> >>  static void bpf_iter_fini_udp(void *priv_data)
> >>  {
> >> -	struct udp_iter_state *st = priv_data;
> >> +	struct bpf_udp_iter_state *iter = priv_data;
> >> +	struct udp_iter_state *st = &iter->state;
> >
> >> -	kfree(st->bpf_seq_afinfo);
> >>  	bpf_iter_fini_seq_net(priv_data);
> >> +	kfree(st->bpf_seq_afinfo);
> >> +	kvfree(iter->batch);
> >>  }
> >
> >>  static const struct bpf_iter_seq_info udp_seq_info = {
> >>  	.seq_ops		= &bpf_iter_udp_seq_ops,
> >>  	.init_seq_private	= bpf_iter_init_udp,
> >>  	.fini_seq_private	= bpf_iter_fini_udp,
> >> -	.seq_priv_size		= sizeof(struct udp_iter_state),
> >> +	.seq_priv_size		= sizeof(struct bpf_udp_iter_state),
> >>  };
> >
> >>  static struct bpf_iter_reg udp_reg_info = {
> >> --
> >> 2.34.1
Martin KaFai Lau March 27, 2023, 10:28 p.m. UTC | #4
On 3/23/23 1:06 PM, Aditi Ghag wrote:
> Batch UDP sockets from BPF iterator that allows for overlapping locking
> semantics in BPF/kernel helpers executed in BPF programs.  This facilitates
> BPF socket destroy kfunc (introduced by follow-up patches) to execute from
> BPF iterator programs.
> 
> Previously, BPF iterators acquired the sock lock and sockets hash table
> bucket lock while executing BPF programs. This prevented BPF helpers that
> again acquire these locks to be executed from BPF iterators.  With the
> batching approach, we acquire a bucket lock, batch all the bucket sockets,
> and then release the bucket lock. This enables BPF or kernel helpers to
> skip sock locking when invoked in the supported BPF contexts.
> 
> The batching logic is similar to the logic implemented in TCP iterator:
> https://lore.kernel.org/bpf/20210701200613.1036157-1-kafai@fb.com/.
> 
> Suggested-by: Martin KaFai Lau <martin.lau@kernel.org>
> Signed-off-by: Aditi Ghag <aditi.ghag@isovalent.com>
> ---
>   include/net/udp.h |   1 +
>   net/ipv4/udp.c    | 255 ++++++++++++++++++++++++++++++++++++++++++++--
>   2 files changed, 247 insertions(+), 9 deletions(-)
> 
> diff --git a/include/net/udp.h b/include/net/udp.h
> index de4b528522bb..d2999447d3f2 100644
> --- a/include/net/udp.h
> +++ b/include/net/udp.h
> @@ -437,6 +437,7 @@ struct udp_seq_afinfo {
>   struct udp_iter_state {
>   	struct seq_net_private  p;
>   	int			bucket;
> +	int			offset;

offset should be moved to 'struct bpf_udp_iter_state' instead. It is specific to 
bpf_iter only.
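
Something like this (sketch of the suggested layout, not a concrete patch):

	struct bpf_udp_iter_state {
		struct udp_iter_state state;	/* shared with the procfs iterator */
		unsigned int cur_sk;
		unsigned int end_sk;
		unsigned int max_sk;
		int offset;			/* bpf_iter-only resume cursor */
		struct sock **batch;
		bool st_bucket_done;
	};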

>   	struct udp_seq_afinfo	*bpf_seq_afinfo;
>   };
>   
> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> index c605d171eb2d..58c620243e47 100644
> --- a/net/ipv4/udp.c
> +++ b/net/ipv4/udp.c
> @@ -3152,6 +3152,171 @@ struct bpf_iter__udp {
>   	int bucket __aligned(8);
>   };
>   
> +struct bpf_udp_iter_state {
> +	struct udp_iter_state state;
> +	unsigned int cur_sk;
> +	unsigned int end_sk;
> +	unsigned int max_sk;
> +	struct sock **batch;
> +	bool st_bucket_done;
> +};
> +
> +static unsigned short seq_file_family(const struct seq_file *seq);
> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
> +				      unsigned int new_batch_sz);
> +
> +static inline bool seq_sk_match(struct seq_file *seq, const struct sock *sk)
> +{
> +	unsigned short family = seq_file_family(seq);
> +
> +	/* AF_UNSPEC is used as a match all */
> +	return ((family == AF_UNSPEC || family == sk->sk_family) &&
> +		net_eq(sock_net(sk), seq_file_net(seq)));
> +}
> +
> +static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
> +{
> +	struct bpf_udp_iter_state *iter = seq->private;
> +	struct udp_iter_state *state = &iter->state;
> +	struct net *net = seq_file_net(seq);
> +	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
> +	struct udp_table *udptable;
> +	struct sock *first_sk = NULL;
> +	struct sock *sk;
> +	unsigned int bucket_sks = 0;
> +	bool resized = false;
> +	int offset = 0;
> +	int new_offset;
> +
> +	/* The current batch is done, so advance the bucket. */
> +	if (iter->st_bucket_done) {
> +		state->bucket++;
> +		state->offset = 0;
> +	}
> +
> +	udptable = udp_get_table_afinfo(afinfo, net);
> +
> +	if (state->bucket > udptable->mask) {
> +		state->bucket = 0;
> +		state->offset = 0;
> +		return NULL;
> +	}
> +
> +again:
> +	/* New batch for the next bucket.
> +	 * Iterate over the hash table to find a bucket with sockets matching
> +	 * the iterator attributes, and return the first matching socket from
> +	 * the bucket. The remaining matched sockets from the bucket are batched
> +	 * before releasing the bucket lock. This allows BPF programs that are
> +	 * called in seq_show to acquire the bucket lock if needed.
> +	 */
> +	iter->cur_sk = 0;
> +	iter->end_sk = 0;
> +	iter->st_bucket_done = false;
> +	first_sk = NULL;
> +	bucket_sks = 0;
> +	offset = state->offset;
> +	new_offset = offset;
> +
> +	for (; state->bucket <= udptable->mask; state->bucket++) {
> +		struct udp_hslot *hslot = &udptable->hash[state->bucket];

Use udptable->hash2, which is hashed by addr and port. It will help to get a
smaller batch. This was the comment given in v2.

> +
> +		if (hlist_empty(&hslot->head)) {
> +			offset = 0;
> +			continue;
> +		}
> +
> +		spin_lock_bh(&hslot->lock);
> +		/* Resume from the last saved position in a bucket before
> +		 * iterator was stopped.
> +		 */
> +		while (offset-- > 0) {
> +			sk_for_each(sk, &hslot->head)
> +				continue;
> +		}

hmm... how do the above while loop and sk_for_each loop actually work?

> +		sk_for_each(sk, &hslot->head) {

Here it starts from the beginning of hslot->head again, which doesn't look right either.

Am I missing something here?
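
(For reference, sk_for_each() is hlist_for_each_entry(..., sk_node), so the
quoted snippet expands to roughly the following; illustration only:)

	while (offset-- > 0) {
		/* Walks the whole chain and leaves sk == NULL; no position is
		 * remembered, so nothing is actually skipped.
		 */
		hlist_for_each_entry(sk, &hslot->head, sk_node)
			continue;
	}
	/* The sk_for_each() below then restarts from hslot->head. */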

> +			if (seq_sk_match(seq, sk)) {
> +				if (!first_sk)
> +					first_sk = sk;
> +				if (iter->end_sk < iter->max_sk) {
> +					sock_hold(sk);
> +					iter->batch[iter->end_sk++] = sk;
> +				}
> +				bucket_sks++;
> +			}
> +			new_offset++;

And this new_offset is incremented outside of seq_sk_match(), so it is not
counting sockets for the seq_file_net(seq) netns alone.

> +		}
> +		spin_unlock_bh(&hslot->lock);
> +
> +		if (first_sk)
> +			break;
> +
> +		/* Reset the current bucket's offset before moving to the next bucket. */
> +		offset = 0;
> +		new_offset = 0;
> +	}
> +
> +	/* All done: no batch made. */
> +	if (!first_sk)
> +		goto ret;
> +
> +	if (iter->end_sk == bucket_sks) {
> +		/* Batching is done for the current bucket; return the first
> +		 * socket to be iterated from the batch.
> +		 */
> +		iter->st_bucket_done = true;
> +		goto ret;
> +	}
> +	if (!resized && !bpf_iter_udp_realloc_batch(iter, bucket_sks * 3 / 2)) {
> +		resized = true;
> +		/* Go back to the previous bucket to resize its batch. */
> +		state->bucket--;
> +		goto again;
> +	}
> +ret:
> +	state->offset = new_offset;
> +	return first_sk;
> +}
> +
> +static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
> +{
> +	struct bpf_udp_iter_state *iter = seq->private;
> +	struct udp_iter_state *state = &iter->state;
> +	struct sock *sk;
> +
> +	/* Whenever seq_next() is called, the iter->cur_sk is
> +	 * done with seq_show(), so unref the iter->cur_sk.
> +	 */
> +	if (iter->cur_sk < iter->end_sk) {
> +		sock_put(iter->batch[iter->cur_sk++]);
> +		++state->offset;

But then, if I read it correctly, this offset counting is only for the netns
specific to seq_file_net(seq), because the batch is specific to
seq_file_net(seq). Is it going to work?

> +	}
> +
> +	/* After updating iter->cur_sk, check if there are more sockets
> +	 * available in the current bucket batch.
> +	 */
> +	if (iter->cur_sk < iter->end_sk) {
> +		sk = iter->batch[iter->cur_sk];
> +	} else {
> +		// Prepare a new batch.
> +		sk = bpf_iter_udp_batch(seq);
> +	}
> +
> +	++*pos;
> +	return sk;
> +}
> +
> +static void *bpf_iter_udp_seq_start(struct seq_file *seq, loff_t *pos)
> +{
> +	/* bpf iter does not support lseek, so it always
> +	 * continue from where it was stop()-ped.
> +	 */
> +	if (*pos)
> +		return bpf_iter_udp_batch(seq);
> +
> +	return SEQ_START_TOKEN;
> +}
> +
>   static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
>   			     struct udp_sock *udp_sk, uid_t uid, int bucket)
>   {
> @@ -3172,18 +3337,38 @@ static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v)
>   	struct bpf_prog *prog;
>   	struct sock *sk = v;
>   	uid_t uid;
> +	bool slow;
> +	int rc;
>   
>   	if (v == SEQ_START_TOKEN)
>   		return 0;
>   
> +	slow = lock_sock_fast(sk);
> +
> +	if (unlikely(sk_unhashed(sk))) {
> +		rc = SEQ_SKIP;
> +		goto unlock;
> +	}
> +
>   	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
>   	meta.seq = seq;
>   	prog = bpf_iter_get_info(&meta, false);
> -	return udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
> +	rc = udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
> +
> +unlock:
> +	unlock_sock_fast(sk, slow);
> +	return rc;
> +}
> +
> +static void bpf_iter_udp_unref_batch(struct bpf_udp_iter_state *iter)

Nit: please use the same naming as in tcp-iter and unix-iter, i.e.
bpf_iter_udp_put_batch().

> +{
> +	while (iter->cur_sk < iter->end_sk)
> +		sock_put(iter->batch[iter->cur_sk++]);
>   }
>   
>   static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
>   {
> +	struct bpf_udp_iter_state *iter = seq->private;
>   	struct bpf_iter_meta meta;
>   	struct bpf_prog *prog;
>   
> @@ -3194,15 +3379,31 @@ static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
>   			(void)udp_prog_seq_show(prog, &meta, v, 0, 0);
>   	}
>   
> -	udp_seq_stop(seq, v);
> +	if (iter->cur_sk < iter->end_sk) {
> +		bpf_iter_udp_unref_batch(iter);
> +		iter->st_bucket_done = false;
> +	}
>   }
>   
>   static const struct seq_operations bpf_iter_udp_seq_ops = {
> -	.start		= udp_seq_start,
> -	.next		= udp_seq_next,
> +	.start		= bpf_iter_udp_seq_start,
> +	.next		= bpf_iter_udp_seq_next,
>   	.stop		= bpf_iter_udp_seq_stop,
>   	.show		= bpf_iter_udp_seq_show,
>   };
> +
> +static unsigned short seq_file_family(const struct seq_file *seq)
> +{
> +	const struct udp_seq_afinfo *afinfo;
> +
> +	/* BPF iterator: bpf programs to filter sockets. */
> +	if (seq->op == &bpf_iter_udp_seq_ops)
> +		return AF_UNSPEC;
> +
> +	/* Proc fs iterator */
> +	afinfo = pde_data(file_inode(seq->file));
> +	return afinfo->family;
> +}
>   #endif
>   
>   const struct seq_operations udp_seq_ops = {
> @@ -3413,9 +3614,30 @@ static struct pernet_operations __net_initdata udp_sysctl_ops = {
>   DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta,
>   		     struct udp_sock *udp_sk, uid_t uid, int bucket)
>   
> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
> +				      unsigned int new_batch_sz)
> +{
> +	struct sock **new_batch;
> +
> +	new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch),
> +				   GFP_USER | __GFP_NOWARN);
> +	if (!new_batch)
> +		return -ENOMEM;
> +
> +	bpf_iter_udp_unref_batch(iter);
> +	kvfree(iter->batch);
> +	iter->batch = new_batch;
> +	iter->max_sk = new_batch_sz;
> +
> +	return 0;
> +}
> +
> +#define INIT_BATCH_SZ 16
> +
>   static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
>   {
> -	struct udp_iter_state *st = priv_data;
> +	struct bpf_udp_iter_state *iter = priv_data;
> +	struct udp_iter_state *st = &iter->state;
>   	struct udp_seq_afinfo *afinfo;
>   	int ret;
>   
> @@ -3427,24 +3649,39 @@ static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
>   	afinfo->udp_table = NULL;
>   	st->bpf_seq_afinfo = afinfo;
>   	ret = bpf_iter_init_seq_net(priv_data, aux);
> -	if (ret)
> +	if (ret) {
>   		kfree(afinfo);
> +		return ret;
> +	}
> +	ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ);
> +	if (ret) {
> +		bpf_iter_fini_seq_net(priv_data);
> +		return ret;
> +	}
> +	iter->cur_sk = 0;
> +	iter->end_sk = 0;
> +	iter->st_bucket_done = false;
> +	st->bucket = 0;
> +	st->offset = 0;

From looking at the tcp and unix counterparts, I don't think these zeroings are
necessary.

> +
>   	return ret;
>   }
>   
>   static void bpf_iter_fini_udp(void *priv_data)
>   {
> -	struct udp_iter_state *st = priv_data;
> +	struct bpf_udp_iter_state *iter = priv_data;
> +	struct udp_iter_state *st = &iter->state;
>   
> -	kfree(st->bpf_seq_afinfo);

The st->bpf_seq_afinfo should no longer be needed. Please remove it from 'struct 
udp_iter_state'.

The other AF_UNSPEC test in the existing udp_get_{first,next,...} should be 
cleaned up to use the refactored seq_sk_match() also.

These two changes should be done as the first one (or two?) cleanup patches 
before the actual udp batching patch. The tcp-iter-batching patch set could be a 
reference point on how the patch set could be structured.

>   	bpf_iter_fini_seq_net(priv_data);
> +	kfree(st->bpf_seq_afinfo);
> +	kvfree(iter->batch);
>   }
>   
>   static const struct bpf_iter_seq_info udp_seq_info = {
>   	.seq_ops		= &bpf_iter_udp_seq_ops,
>   	.init_seq_private	= bpf_iter_init_udp,
>   	.fini_seq_private	= bpf_iter_fini_udp,
> -	.seq_priv_size		= sizeof(struct udp_iter_state),
> +	.seq_priv_size		= sizeof(struct bpf_udp_iter_state),
>   };
>   
>   static struct bpf_iter_reg udp_reg_info = {
Aditi Ghag March 28, 2023, 5:06 p.m. UTC | #5
> On Mar 27, 2023, at 3:28 PM, Martin KaFai Lau <martin.lau@linux.dev> wrote:
> 
> On 3/23/23 1:06 PM, Aditi Ghag wrote:
>> Batch UDP sockets from BPF iterator that allows for overlapping locking
>> semantics in BPF/kernel helpers executed in BPF programs.  This facilitates
>> BPF socket destroy kfunc (introduced by follow-up patches) to execute from
>> BPF iterator programs.
>> Previously, BPF iterators acquired the sock lock and sockets hash table
>> bucket lock while executing BPF programs. This prevented BPF helpers that
>> again acquire these locks to be executed from BPF iterators.  With the
>> batching approach, we acquire a bucket lock, batch all the bucket sockets,
>> and then release the bucket lock. This enables BPF or kernel helpers to
>> skip sock locking when invoked in the supported BPF contexts.
>> The batching logic is similar to the logic implemented in TCP iterator:
>> https://lore.kernel.org/bpf/20210701200613.1036157-1-kafai@fb.com/.
>> Suggested-by: Martin KaFai Lau <martin.lau@kernel.org>
>> Signed-off-by: Aditi Ghag <aditi.ghag@isovalent.com>
>> ---
>>  include/net/udp.h |   1 +
>>  net/ipv4/udp.c    | 255 ++++++++++++++++++++++++++++++++++++++++++++--
>>  2 files changed, 247 insertions(+), 9 deletions(-)
>> diff --git a/include/net/udp.h b/include/net/udp.h
>> index de4b528522bb..d2999447d3f2 100644
>> --- a/include/net/udp.h
>> +++ b/include/net/udp.h
>> @@ -437,6 +437,7 @@ struct udp_seq_afinfo {
>>  struct udp_iter_state {
>>  	struct seq_net_private  p;
>>  	int			bucket;
>> +	int			offset;
> 
> offset should be moved to 'struct bpf_udp_iter_state' instead. It is specific to bpf_iter only.

Sure, I'll move it.

> 
>>  	struct udp_seq_afinfo	*bpf_seq_afinfo;
>>  };
>>  diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
>> index c605d171eb2d..58c620243e47 100644
>> --- a/net/ipv4/udp.c
>> +++ b/net/ipv4/udp.c
>> @@ -3152,6 +3152,171 @@ struct bpf_iter__udp {
>>  	int bucket __aligned(8);
>>  };
>>  +struct bpf_udp_iter_state {
>> +	struct udp_iter_state state;
>> +	unsigned int cur_sk;
>> +	unsigned int end_sk;
>> +	unsigned int max_sk;
>> +	struct sock **batch;
>> +	bool st_bucket_done;
>> +};
>> +
>> +static unsigned short seq_file_family(const struct seq_file *seq);
>> +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
>> +				      unsigned int new_batch_sz);
>> +
>> +static inline bool seq_sk_match(struct seq_file *seq, const struct sock *sk)
>> +{
>> +	unsigned short family = seq_file_family(seq);
>> +
>> +	/* AF_UNSPEC is used as a match all */
>> +	return ((family == AF_UNSPEC || family == sk->sk_family) &&
>> +		net_eq(sock_net(sk), seq_file_net(seq)));
>> +}
>> +
>> +static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
>> +{
>> +	struct bpf_udp_iter_state *iter = seq->private;
>> +	struct udp_iter_state *state = &iter->state;
>> +	struct net *net = seq_file_net(seq);
>> +	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
>> +	struct udp_table *udptable;
>> +	struct sock *first_sk = NULL;
>> +	struct sock *sk;
>> +	unsigned int bucket_sks = 0;
>> +	bool resized = false;
>> +	int offset = 0;
>> +	int new_offset;
>> +
>> +	/* The current batch is done, so advance the bucket. */
>> +	if (iter->st_bucket_done) {
>> +		state->bucket++;
>> +		state->offset = 0;
>> +	}
>> +
>> +	udptable = udp_get_table_afinfo(afinfo, net);
>> +
>> +	if (state->bucket > udptable->mask) {
>> +		state->bucket = 0;
>> +		state->offset = 0;
>> +		return NULL;
>> +	}
>> +
>> +again:
>> +	/* New batch for the next bucket.
>> +	 * Iterate over the hash table to find a bucket with sockets matching
>> +	 * the iterator attributes, and return the first matching socket from
>> +	 * the bucket. The remaining matched sockets from the bucket are batched
>> +	 * before releasing the bucket lock. This allows BPF programs that are
>> +	 * called in seq_show to acquire the bucket lock if needed.
>> +	 */
>> +	iter->cur_sk = 0;
>> +	iter->end_sk = 0;
>> +	iter->st_bucket_done = false;
>> +	first_sk = NULL;
>> +	bucket_sks = 0;
>> +	offset = state->offset;
>> +	new_offset = offset;
>> +
>> +	for (; state->bucket <= udptable->mask; state->bucket++) {
>> +		struct udp_hslot *hslot = &udptable->hash[state->bucket];
> 
> Use udptable->hash"2" which is hashed by addr and port. It will help to get a smaller batch. It was the comment given in v2.

I thought I replied to your review comment, but it looks like I didn't. My bad!
I already gave it a shot, but I'll need to understand better how udptable->hash2 is populated. When I swapped hash with hash2, there were no sockets to iterate. Am I missing something obvious?
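
(For reference, a sketch of what the hash2 walk could look like, assuming the
udp_portaddr_for_each_entry() helper in net/ipv4/udp.c; hash2 chains sockets
on a different node than sk_node, which would also explain why a plain
sk_for_each() finds nothing there. Illustration only, not a tested change.)

	for (; state->bucket <= udptable->mask; state->bucket++) {
		struct udp_hslot *hslot2 = &udptable->hash2[state->bucket];

		if (hlist_empty(&hslot2->head))
			continue;

		spin_lock_bh(&hslot2->lock);
		/* Sockets here hang off sk->__sk_common.skc_portaddr_node. */
		udp_portaddr_for_each_entry(sk, &hslot2->head) {
			if (seq_sk_match(seq, sk)) {
				/* batch sk as in the hash walk */
			}
		}
		spin_unlock_bh(&hslot2->lock);
	}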

> 
>> +
>> +		if (hlist_empty(&hslot->head)) {
>> +			offset = 0;
>> +			continue;
>> +		}
>> +
>> +		spin_lock_bh(&hslot->lock);
>> +		/* Resume from the last saved position in a bucket before
>> +		 * iterator was stopped.
>> +		 */
>> +		while (offset-- > 0) {
>> +			sk_for_each(sk, &hslot->head)
>> +				continue;
>> +		}
> 
> hmm... how does the above while loop and sk_for_each loop actually work?
> 
>> +		sk_for_each(sk, &hslot->head) {
> 
> Here starts from the beginning of the hslot->head again. doesn't look right also.
> 
> Am I missing something here?
> 
>> +			if (seq_sk_match(seq, sk)) {
>> +				if (!first_sk)
>> +					first_sk = sk;
>> +				if (iter->end_sk < iter->max_sk) {
>> +					sock_hold(sk);
>> +					iter->batch[iter->end_sk++] = sk;
>> +				}
>> +				bucket_sks++;
>> +			}
>> +			new_offset++;
> 
> And this new_offset is outside of seq_sk_match, so it is not counting for the seq_file_net(seq) netns alone.

This logic to resume the iterator is buggy, indeed! So I was trying to account for the cases where the current bucket could've been updated since we released the bucket lock.
This is what I intended to do -

+loop:
                sk_for_each(sk, &hslot->head) {
                        if (seq_sk_match(seq, sk)) {
+                               /* Resume from the last saved position in the
+                                * bucket before iterator was stopped.
+                                */
+                               while (offset && offset-- > 0)
+                                       goto loop;
                                if (!first_sk)
                                        first_sk = sk;
                                if (iter->end_sk < iter->max_sk) {
@@ -3245,8 +3244,8 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
                                        iter->batch[iter->end_sk++] = sk;
                                }
                                bucket_sks++;
+                              new_offset++;
                        }

This handles the case where sockets that weren't iterated in the previous round got deleted by the time the iterator was resumed. But it's possible that previously iterated sockets got deleted before the iterator was later resumed, and the offset is now outdated. Ideally, the iterator should be invalidated in this case, but there is no way to track this, is there? Any thoughts?


> 
>> +		}
>> +		spin_unlock_bh(&hslot->lock);
>> +
>> +		if (first_sk)
>> +			break;
>> +
>> +		/* Reset the current bucket's offset before moving to the next bucket. */
>> +		offset = 0;
>> +		new_offset = 0;
>> +	}
>> +
>> +	/* All done: no batch made. */
>> +	if (!first_sk)
>> +		goto ret;
>> +
>> +	if (iter->end_sk == bucket_sks) {
>> +		/* Batching is done for the current bucket; return the first
>> +		 * socket to be iterated from the batch.
>> +		 */
>> +		iter->st_bucket_done = true;
>> +		goto ret;
>> +	}
>> +	if (!resized && !bpf_iter_udp_realloc_batch(iter, bucket_sks * 3 / 2)) {
>> +		resized = true;
>> +		/* Go back to the previous bucket to resize its batch. */
>> +		state->bucket--;
>> +		goto again;
>> +	}
>> +ret:
>> +	state->offset = new_offset;
>> +	return first_sk;
>> +}
>> +
>> +static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
>> +{
>> +	struct bpf_udp_iter_state *iter = seq->private;
>> +	struct udp_iter_state *state = &iter->state;
>> +	struct sock *sk;
>> +
>> +	/* Whenever seq_next() is called, the iter->cur_sk is
>> +	 * done with seq_show(), so unref the iter->cur_sk.
>> +	 */
>> +	if (iter->cur_sk < iter->end_sk) {
>> +		sock_put(iter->batch[iter->cur_sk++]);
>> +		++state->offset;
> 
> but then,
> if I read it correctly, this offset counting is only for netns specific to seq_file_net(seq) because batch is specific to seq_file_net(net). Is it going to work?
> 
>> +	}
>> +
>> +	/* After updating iter->cur_sk, check if there are more sockets
>> +	 * available in the current bucket batch.
>> +	 */
>> +	if (iter->cur_sk < iter->end_sk) {
>> +		sk = iter->batch[iter->cur_sk];
>> +	} else {
>> +		// Prepare a new batch.
>> +		sk = bpf_iter_udp_batch(seq);
>> +	}
>> +
>> +	++*pos;
>> +	return sk;
>> +}
>> +
>> +static void *bpf_iter_udp_seq_start(struct seq_file *seq, loff_t *pos)
>> +{
>> +	/* bpf iter does not support lseek, so it always
>> +	 * continue from where it was stop()-ped.
>> +	 */
>> +	if (*pos)
>> +		return bpf_iter_udp_batch(seq);
>> +
>> +	return SEQ_START_TOKEN;
>> +}
>> +
>>  static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
>>  			     struct udp_sock *udp_sk, uid_t uid, int bucket)
>>  {
>> @@ -3172,18 +3337,38 @@ static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v)
>>  	struct bpf_prog *prog;
>>  	struct sock *sk = v;
>>  	uid_t uid;
>> +	bool slow;
>> +	int rc;
>>    	if (v == SEQ_START_TOKEN)
>>  		return 0;
>>  +	slow = lock_sock_fast(sk);
>> +
>> +	if (unlikely(sk_unhashed(sk))) {
>> +		rc = SEQ_SKIP;
>> +		goto unlock;
>> +	}
>> +
>>  	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
>>  	meta.seq = seq;
>>  	prog = bpf_iter_get_info(&meta, false);
>> -	return udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
>> +	rc = udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
>> +
>> +unlock:
>> +	unlock_sock_fast(sk, slow);
>> +	return rc;
>> +}
>> +
>> +static void bpf_iter_udp_unref_batch(struct bpf_udp_iter_state *iter)
> 
> nit. Please use the same naming as in tcp-iter and unix-iter, so bpf_iter_udp_put_batch().

Ack
> 
>> +{
>> +	while (iter->cur_sk < iter->end_sk)
>> +		sock_put(iter->batch[iter->cur_sk++]);
>>  }
>>    static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
>>  {
>> +	struct bpf_udp_iter_state *iter = seq->private;
>>  	struct bpf_iter_meta meta;
>>  	struct bpf_prog *prog;
>>  @@ -3194,15 +3379,31 @@ static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
>>  			(void)udp_prog_seq_show(prog, &meta, v, 0, 0);
>>  	}
>>  -	udp_seq_stop(seq, v);
>> +	if (iter->cur_sk < iter->end_sk) {
>> +		bpf_iter_udp_unref_batch(iter);
>> +		iter->st_bucket_done = false;
>> +	}
>>  }
>>    static const struct seq_operations bpf_iter_udp_seq_ops = {
>> -	.start		= udp_seq_start,
>> -	.next		= udp_seq_next,
>> +	.start		= bpf_iter_udp_seq_start,
>> +	.next		= bpf_iter_udp_seq_next,
>>  	.stop		= bpf_iter_udp_seq_stop,
>>  	.show		= bpf_iter_udp_seq_show,
>>  };
>> +
>> +static unsigned short seq_file_family(const struct seq_file *seq)
>> +{
>> +	const struct udp_seq_afinfo *afinfo;
>> +
>> +	/* BPF iterator: bpf programs to filter sockets. */
>> +	if (seq->op == &bpf_iter_udp_seq_ops)
>> +		return AF_UNSPEC;
>> +
>> +	/* Proc fs iterator */
>> +	afinfo = pde_data(file_inode(seq->file));
>> +	return afinfo->family;
>> +}
>>  #endif
>>    const struct seq_operations udp_seq_ops = {
>> @@ -3413,9 +3614,30 @@ static struct pernet_operations __net_initdata udp_sysctl_ops = {
>>  DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta,
>>  		     struct udp_sock *udp_sk, uid_t uid, int bucket)
>>  +static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
>> +				      unsigned int new_batch_sz)
>> +{
>> +	struct sock **new_batch;
>> +
>> +	new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch),
>> +				   GFP_USER | __GFP_NOWARN);
>> +	if (!new_batch)
>> +		return -ENOMEM;
>> +
>> +	bpf_iter_udp_unref_batch(iter);
>> +	kvfree(iter->batch);
>> +	iter->batch = new_batch;
>> +	iter->max_sk = new_batch_sz;
>> +
>> +	return 0;
>> +}
>> +
>> +#define INIT_BATCH_SZ 16
>> +
>>  static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
>>  {
>> -	struct udp_iter_state *st = priv_data;
>> +	struct bpf_udp_iter_state *iter = priv_data;
>> +	struct udp_iter_state *st = &iter->state;
>>  	struct udp_seq_afinfo *afinfo;
>>  	int ret;
>>  @@ -3427,24 +3649,39 @@ static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
>>  	afinfo->udp_table = NULL;
>>  	st->bpf_seq_afinfo = afinfo;
>>  	ret = bpf_iter_init_seq_net(priv_data, aux);
>> -	if (ret)
>> +	if (ret) {
>>  		kfree(afinfo);
>> +		return ret;
>> +	}
>> +	ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ);
>> +	if (ret) {
>> +		bpf_iter_fini_seq_net(priv_data);
>> +		return ret;
>> +	}
>> +	iter->cur_sk = 0;
>> +	iter->end_sk = 0;
>> +	iter->st_bucket_done = false;
>> +	st->bucket = 0;
>> +	st->offset = 0;
> 
> From looking at the tcp and unix counter part, I don't think this zeroings is necessary.

Ack

> 
>> +
>>  	return ret;
>>  }
>>    static void bpf_iter_fini_udp(void *priv_data)
>>  {
>> -	struct udp_iter_state *st = priv_data;
>> +	struct bpf_udp_iter_state *iter = priv_data;
>> +	struct udp_iter_state *st = &iter->state;
>>  -	kfree(st->bpf_seq_afinfo);
> 
> The st->bpf_seq_afinfo should no longer be needed. Please remove it from 'struct udp_iter_state'.
> 
> The other AF_UNSPEC test in the existing udp_get_{first,next,...} should be cleaned up to use the refactored seq_sk_match() also.
> 
> These two changes should be done as the first one (or two?) cleanup patches before the actual udp batching patch. The tcp-iter-batching patch set could be a reference point on how the patch set could be structured.

Ack for both the clean-up and reshuffling. 

> 
>>  	bpf_iter_fini_seq_net(priv_data);
>> +	kfree(st->bpf_seq_afinfo);
>> +	kvfree(iter->batch);
>>  }
>>    static const struct bpf_iter_seq_info udp_seq_info = {
>>  	.seq_ops		= &bpf_iter_udp_seq_ops,
>>  	.init_seq_private	= bpf_iter_init_udp,
>>  	.fini_seq_private	= bpf_iter_fini_udp,
>> -	.seq_priv_size		= sizeof(struct udp_iter_state),
>> +	.seq_priv_size		= sizeof(struct bpf_udp_iter_state),
>>  };
>>    static struct bpf_iter_reg udp_reg_info = {
Martin KaFai Lau March 28, 2023, 9:33 p.m. UTC | #6
On 3/28/23 10:06 AM, Aditi Ghag wrote:
>>> +static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
>>> +{
>>> +	struct bpf_udp_iter_state *iter = seq->private;
>>> +	struct udp_iter_state *state = &iter->state;
>>> +	struct net *net = seq_file_net(seq);
>>> +	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
>>> +	struct udp_table *udptable;
>>> +	struct sock *first_sk = NULL;
>>> +	struct sock *sk;
>>> +	unsigned int bucket_sks = 0;
>>> +	bool resized = false;
>>> +	int offset = 0;
>>> +	int new_offset;
>>> +
>>> +	/* The current batch is done, so advance the bucket. */
>>> +	if (iter->st_bucket_done) {
>>> +		state->bucket++;
>>> +		state->offset = 0;
>>> +	}
>>> +
>>> +	udptable = udp_get_table_afinfo(afinfo, net);
>>> +
>>> +	if (state->bucket > udptable->mask) {
>>> +		state->bucket = 0;
>>> +		state->offset = 0;
>>> +		return NULL;
>>> +	}
>>> +
>>> +again:
>>> +	/* New batch for the next bucket.
>>> +	 * Iterate over the hash table to find a bucket with sockets matching
>>> +	 * the iterator attributes, and return the first matching socket from
>>> +	 * the bucket. The remaining matched sockets from the bucket are batched
>>> +	 * before releasing the bucket lock. This allows BPF programs that are
>>> +	 * called in seq_show to acquire the bucket lock if needed.
>>> +	 */
>>> +	iter->cur_sk = 0;
>>> +	iter->end_sk = 0;
>>> +	iter->st_bucket_done = false;
>>> +	first_sk = NULL;
>>> +	bucket_sks = 0;
>>> +	offset = state->offset;
>>> +	new_offset = offset;
>>> +
>>> +	for (; state->bucket <= udptable->mask; state->bucket++) {
>>> +		struct udp_hslot *hslot = &udptable->hash[state->bucket];
>>
>> Use udptable->hash"2" which is hashed by addr and port. It will help to get a smaller batch. It was the comment given in v2.
> 
> I thought I replied to your review comment, but looks like I didn't. My bad!
> I already gave it a shot, and I'll need to understand better how udptable->hash2 is populated. When I swapped hash with hash2, there were no sockets to iterate. Am I missing something obvious?

Take a look at udp_lib_lport_inuse2() on how it iterates.
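
For reference, a rough and untested sketch of how the batching loop could walk a
hash"2" bucket, following the pattern in udp_lib_lport_inuse2(). The hash2 chain
is linked via __sk_common.skc_portaddr_node rather than skc_node, so it has to be
walked with udp_portaddr_for_each_entry() instead of sk_for_each(); the hslot2
naming below just mirrors the existing udp.c code:

	/* untested sketch, not the actual patch */
	struct udp_hslot *hslot2 = &udptable->hash2[state->bucket];

	spin_lock_bh(&hslot2->lock);
	udp_portaddr_for_each_entry(sk, &hslot2->head) {
		/* same seq_sk_match() + batching as in the hash-based loop */
	}
	spin_unlock_bh(&hslot2->lock);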

> 
>>
>>> +
>>> +		if (hlist_empty(&hslot->head)) {
>>> +			offset = 0;
>>> +			continue;
>>> +		}
>>> +
>>> +		spin_lock_bh(&hslot->lock);
>>> +		/* Resume from the last saved position in a bucket before
>>> +		 * iterator was stopped.
>>> +		 */
>>> +		while (offset-- > 0) {
>>> +			sk_for_each(sk, &hslot->head)
>>> +				continue;
>>> +		}
>>
>> hmm... how does the above while loop and sk_for_each loop actually work?
>>
>>> +		sk_for_each(sk, &hslot->head) {
>>
>> Here starts from the beginning of the hslot->head again. doesn't look right also.
>>
>> Am I missing something here?
>>
>>> +			if (seq_sk_match(seq, sk)) {
>>> +				if (!first_sk)
>>> +					first_sk = sk;
>>> +				if (iter->end_sk < iter->max_sk) {
>>> +					sock_hold(sk);
>>> +					iter->batch[iter->end_sk++] = sk;
>>> +				}
>>> +				bucket_sks++;
>>> +			}
>>> +			new_offset++;
>>
>> And this new_offset is outside of seq_sk_match, so it is not counting for the seq_file_net(seq) netns alone.
> 
> This logic to resume the iterator is buggy, indeed! So I was trying to account for the cases where the current bucket could've been updated since we released the bucket lock.
> This is what I intended to do -
> 
> +loop:
>                  sk_for_each(sk, &hslot->head) {
>                          if (seq_sk_match(seq, sk)) {
> +                               /* Resume from the last saved position in the
> +                                * bucket before iterator was stopped.
> +                                */
> +                               while (offset && offset-- > 0)
> +                                       goto loop;

still does not look right. merely a loop decrementing offset one at a time and
then going back to the beginning of hslot->head?

A quick (untested and uncompiled) thought :

				/* Skip the first 'offset' number of sk
				 * and not putting them in the iter->batch[].
				 */
				if (offset) {
					offset--;
					continue;
				}

>                                  if (!first_sk)
>                                          first_sk = sk;
>                                  if (iter->end_sk < iter->max_sk) {
> @@ -3245,8 +3244,8 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
>                                          iter->batch[iter->end_sk++] = sk;
>                                  }
>                                  bucket_sks++;
> +                              new_offset++;
>                          }
> 
> This handles the case where sockets that weren't iterated in the previous round got deleted by the time the iterator was resumed. But it's possible that previously iterated sockets got deleted before the iterator was later resumed, and the offset is now outdated. Ideally, the iterator should be invalidated in this case, but there is no way to track this, is there? Any thoughts?

I would not worry about this update-in-between case. A race will happen anyway
when the bucket lock is released. This should be very unlikely when hash"2" is used.
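
Putting the two together (again untested), with the offset only counted for
sockets that pass seq_sk_match(), the resume part of the batching loop could
look roughly like:

		udp_portaddr_for_each_entry(sk, &hslot2->head) {
			if (seq_sk_match(seq, sk)) {
				/* Untested sketch: skip the first 'offset'
				 * matching sockets; they were already shown
				 * before the iterator was stop()-ped.
				 */
				if (offset) {
					offset--;
					continue;
				}
				if (!first_sk)
					first_sk = sk;
				if (iter->end_sk < iter->max_sk) {
					sock_hold(sk);
					iter->batch[iter->end_sk++] = sk;
				}
				bucket_sks++;
			}
		}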
Aditi Ghag March 29, 2023, 4:20 p.m. UTC | #7
> On Mar 28, 2023, at 2:33 PM, Martin KaFai Lau <martin.lau@linux.dev> wrote:
> 
> On 3/28/23 10:06 AM, Aditi Ghag wrote:
>>>> +static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
>>>> +{
>>>> +	struct bpf_udp_iter_state *iter = seq->private;
>>>> +	struct udp_iter_state *state = &iter->state;
>>>> +	struct net *net = seq_file_net(seq);
>>>> +	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
>>>> +	struct udp_table *udptable;
>>>> +	struct sock *first_sk = NULL;
>>>> +	struct sock *sk;
>>>> +	unsigned int bucket_sks = 0;
>>>> +	bool resized = false;
>>>> +	int offset = 0;
>>>> +	int new_offset;
>>>> +
>>>> +	/* The current batch is done, so advance the bucket. */
>>>> +	if (iter->st_bucket_done) {
>>>> +		state->bucket++;
>>>> +		state->offset = 0;
>>>> +	}
>>>> +
>>>> +	udptable = udp_get_table_afinfo(afinfo, net);
>>>> +
>>>> +	if (state->bucket > udptable->mask) {
>>>> +		state->bucket = 0;
>>>> +		state->offset = 0;
>>>> +		return NULL;
>>>> +	}
>>>> +
>>>> +again:
>>>> +	/* New batch for the next bucket.
>>>> +	 * Iterate over the hash table to find a bucket with sockets matching
>>>> +	 * the iterator attributes, and return the first matching socket from
>>>> +	 * the bucket. The remaining matched sockets from the bucket are batched
>>>> +	 * before releasing the bucket lock. This allows BPF programs that are
>>>> +	 * called in seq_show to acquire the bucket lock if needed.
>>>> +	 */
>>>> +	iter->cur_sk = 0;
>>>> +	iter->end_sk = 0;
>>>> +	iter->st_bucket_done = false;
>>>> +	first_sk = NULL;
>>>> +	bucket_sks = 0;
>>>> +	offset = state->offset;
>>>> +	new_offset = offset;
>>>> +
>>>> +	for (; state->bucket <= udptable->mask; state->bucket++) {
>>>> +		struct udp_hslot *hslot = &udptable->hash[state->bucket];
>>> 
>>> Use udptable->hash"2" which is hashed by addr and port. It will help to get a smaller batch. It was the comment given in v2.
>> I thought I replied to your review comment, but looks like I didn't. My bad!
>> I already gave it a shot, and I'll need to understand better how udptable->hash2 is populated. When I swapped hash with hash2, there were no sockets to iterate. Am I missing something obvious?
> 
> Take a look at udp_lib_lport_inuse2() on how it iterates.

Thanks! I've updated the code to use hash2 instead of hash.

> 
>>> 
>>>> +
>>>> +		if (hlist_empty(&hslot->head)) {
>>>> +			offset = 0;
>>>> +			continue;
>>>> +		}
>>>> +
>>>> +		spin_lock_bh(&hslot->lock);
>>>> +		/* Resume from the last saved position in a bucket before
>>>> +		 * iterator was stopped.
>>>> +		 */
>>>> +		while (offset-- > 0) {
>>>> +			sk_for_each(sk, &hslot->head)
>>>> +				continue;
>>>> +		}
>>> 
>>> hmm... how does the above while loop and sk_for_each loop actually work?
>>> 
>>>> +		sk_for_each(sk, &hslot->head) {
>>> 
>>> Here starts from the beginning of the hslot->head again. doesn't look right also.
>>> 
>>> Am I missing something here?
>>> 
>>>> +			if (seq_sk_match(seq, sk)) {
>>>> +				if (!first_sk)
>>>> +					first_sk = sk;
>>>> +				if (iter->end_sk < iter->max_sk) {
>>>> +					sock_hold(sk);
>>>> +					iter->batch[iter->end_sk++] = sk;
>>>> +				}
>>>> +				bucket_sks++;
>>>> +			}
>>>> +			new_offset++;
>>> 
>>> And this new_offset is outside of seq_sk_match, so it is not counting for the seq_file_net(seq) netns alone.
>> This logic to resume the iterator is buggy, indeed! So I was trying to account for the cases where the current bucket could've been updated since we released the bucket lock.
>> This is what I intended to do -
>> +loop:
>>                 sk_for_each(sk, &hslot->head) {
>>                         if (seq_sk_match(seq, sk)) {
>> +                               /* Resume from the last saved position in the
>> +                                * bucket before iterator was stopped.
>> +                                */
>> +                               while (offset && offset-- > 0)
>> +                                       goto loop;
> 
> still does not look right. merely a loop decrementing offset one at a time and then going back to the beginning of hslot->head?

Yes, I realized that the sk_for_each macro doesn't continue the way I thought it would. I've fixed it.
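
For reference, sk_for_each() is defined in include/net/sock.h as

	#define sk_for_each(__sk, list) \
		hlist_for_each_entry(__sk, list, sk_node)

so each use of it starts walking from the head of the list again; a nested
'continue' (or a goto back to a label in front of it) never carries the
position forward.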

> 
> A quick (untested and uncompiled) thought :
> 
> 				/* Skip the first 'offset' number of sk
> 				 * and not putting them in the iter->batch[].
> 				 */
> 				if (offset) {
> 					offset--;
> 					continue;
> 				}
> 
>>                                 if (!first_sk)
>>                                         first_sk = sk;
>>                                 if (iter->end_sk < iter->max_sk) {
>> @@ -3245,8 +3244,8 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
>>                                         iter->batch[iter->end_sk++] = sk;
>>                                 }
>>                                 bucket_sks++;
>> +                              new_offset++;
>>                         }
>> This handles the case where sockets that weren't iterated in the previous round got deleted by the time the iterator was resumed. But it's possible that previously iterated sockets got deleted before the iterator was later resumed, and the offset is now outdated. Ideally, the iterator should be invalidated in this case, but there is no way to track this, is there? Any thoughts?
> 
> I would not worry about this update-in-between case. A race will happen anyway when the bucket lock is released. This should be very unlikely when hash"2" is used.
> 
> 

That makes sense.

Patch

diff --git a/include/net/udp.h b/include/net/udp.h
index de4b528522bb..d2999447d3f2 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -437,6 +437,7 @@  struct udp_seq_afinfo {
 struct udp_iter_state {
 	struct seq_net_private  p;
 	int			bucket;
+	int			offset;
 	struct udp_seq_afinfo	*bpf_seq_afinfo;
 };
 
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index c605d171eb2d..58c620243e47 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -3152,6 +3152,171 @@  struct bpf_iter__udp {
 	int bucket __aligned(8);
 };
 
+struct bpf_udp_iter_state {
+	struct udp_iter_state state;
+	unsigned int cur_sk;
+	unsigned int end_sk;
+	unsigned int max_sk;
+	struct sock **batch;
+	bool st_bucket_done;
+};
+
+static unsigned short seq_file_family(const struct seq_file *seq);
+static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
+				      unsigned int new_batch_sz);
+
+static inline bool seq_sk_match(struct seq_file *seq, const struct sock *sk)
+{
+	unsigned short family = seq_file_family(seq);
+
+	/* AF_UNSPEC is used as a match all */
+	return ((family == AF_UNSPEC || family == sk->sk_family) &&
+		net_eq(sock_net(sk), seq_file_net(seq)));
+}
+
+static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
+{
+	struct bpf_udp_iter_state *iter = seq->private;
+	struct udp_iter_state *state = &iter->state;
+	struct net *net = seq_file_net(seq);
+	struct udp_seq_afinfo *afinfo = state->bpf_seq_afinfo;
+	struct udp_table *udptable;
+	struct sock *first_sk = NULL;
+	struct sock *sk;
+	unsigned int bucket_sks = 0;
+	bool resized = false;
+	int offset = 0;
+	int new_offset;
+
+	/* The current batch is done, so advance the bucket. */
+	if (iter->st_bucket_done) {
+		state->bucket++;
+		state->offset = 0;
+	}
+
+	udptable = udp_get_table_afinfo(afinfo, net);
+
+	if (state->bucket > udptable->mask) {
+		state->bucket = 0;
+		state->offset = 0;
+		return NULL;
+	}
+
+again:
+	/* New batch for the next bucket.
+	 * Iterate over the hash table to find a bucket with sockets matching
+	 * the iterator attributes, and return the first matching socket from
+	 * the bucket. The remaining matched sockets from the bucket are batched
+	 * before releasing the bucket lock. This allows BPF programs that are
+	 * called in seq_show to acquire the bucket lock if needed.
+	 */
+	iter->cur_sk = 0;
+	iter->end_sk = 0;
+	iter->st_bucket_done = false;
+	first_sk = NULL;
+	bucket_sks = 0;
+	offset = state->offset;
+	new_offset = offset;
+
+	for (; state->bucket <= udptable->mask; state->bucket++) {
+		struct udp_hslot *hslot = &udptable->hash[state->bucket];
+
+		if (hlist_empty(&hslot->head)) {
+			offset = 0;
+			continue;
+		}
+
+		spin_lock_bh(&hslot->lock);
+		/* Resume from the last saved position in a bucket before
+		 * iterator was stopped.
+		 */
+		while (offset-- > 0) {
+			sk_for_each(sk, &hslot->head)
+				continue;
+		}
+		sk_for_each(sk, &hslot->head) {
+			if (seq_sk_match(seq, sk)) {
+				if (!first_sk)
+					first_sk = sk;
+				if (iter->end_sk < iter->max_sk) {
+					sock_hold(sk);
+					iter->batch[iter->end_sk++] = sk;
+				}
+				bucket_sks++;
+			}
+			new_offset++;
+		}
+		spin_unlock_bh(&hslot->lock);
+
+		if (first_sk)
+			break;
+
+		/* Reset the current bucket's offset before moving to the next bucket. */
+		offset = 0;
+		new_offset = 0;
+	}
+
+	/* All done: no batch made. */
+	if (!first_sk)
+		goto ret;
+
+	if (iter->end_sk == bucket_sks) {
+		/* Batching is done for the current bucket; return the first
+		 * socket to be iterated from the batch.
+		 */
+		iter->st_bucket_done = true;
+		goto ret;
+	}
+	if (!resized && !bpf_iter_udp_realloc_batch(iter, bucket_sks * 3 / 2)) {
+		resized = true;
+		/* Go back to the previous bucket to resize its batch. */
+		state->bucket--;
+		goto again;
+	}
+ret:
+	state->offset = new_offset;
+	return first_sk;
+}
+
+static void *bpf_iter_udp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
+{
+	struct bpf_udp_iter_state *iter = seq->private;
+	struct udp_iter_state *state = &iter->state;
+	struct sock *sk;
+
+	/* Whenever seq_next() is called, the iter->cur_sk is
+	 * done with seq_show(), so unref the iter->cur_sk.
+	 */
+	if (iter->cur_sk < iter->end_sk) {
+		sock_put(iter->batch[iter->cur_sk++]);
+		++state->offset;
+	}
+
+	/* After updating iter->cur_sk, check if there are more sockets
+	 * available in the current bucket batch.
+	 */
+	if (iter->cur_sk < iter->end_sk) {
+		sk = iter->batch[iter->cur_sk];
+	} else {
+		// Prepare a new batch.
+		sk = bpf_iter_udp_batch(seq);
+	}
+
+	++*pos;
+	return sk;
+}
+
+static void *bpf_iter_udp_seq_start(struct seq_file *seq, loff_t *pos)
+{
+	/* bpf iter does not support lseek, so it always
+	 * continue from where it was stop()-ped.
+	 */
+	if (*pos)
+		return bpf_iter_udp_batch(seq);
+
+	return SEQ_START_TOKEN;
+}
+
 static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
 			     struct udp_sock *udp_sk, uid_t uid, int bucket)
 {
@@ -3172,18 +3337,38 @@  static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v)
 	struct bpf_prog *prog;
 	struct sock *sk = v;
 	uid_t uid;
+	bool slow;
+	int rc;
 
 	if (v == SEQ_START_TOKEN)
 		return 0;
 
+	slow = lock_sock_fast(sk);
+
+	if (unlikely(sk_unhashed(sk))) {
+		rc = SEQ_SKIP;
+		goto unlock;
+	}
+
 	uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
 	meta.seq = seq;
 	prog = bpf_iter_get_info(&meta, false);
-	return udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
+	rc = udp_prog_seq_show(prog, &meta, v, uid, state->bucket);
+
+unlock:
+	unlock_sock_fast(sk, slow);
+	return rc;
+}
+
+static void bpf_iter_udp_unref_batch(struct bpf_udp_iter_state *iter)
+{
+	while (iter->cur_sk < iter->end_sk)
+		sock_put(iter->batch[iter->cur_sk++]);
 }
 
 static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
 {
+	struct bpf_udp_iter_state *iter = seq->private;
 	struct bpf_iter_meta meta;
 	struct bpf_prog *prog;
 
@@ -3194,15 +3379,31 @@  static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v)
 			(void)udp_prog_seq_show(prog, &meta, v, 0, 0);
 	}
 
-	udp_seq_stop(seq, v);
+	if (iter->cur_sk < iter->end_sk) {
+		bpf_iter_udp_unref_batch(iter);
+		iter->st_bucket_done = false;
+	}
 }
 
 static const struct seq_operations bpf_iter_udp_seq_ops = {
-	.start		= udp_seq_start,
-	.next		= udp_seq_next,
+	.start		= bpf_iter_udp_seq_start,
+	.next		= bpf_iter_udp_seq_next,
 	.stop		= bpf_iter_udp_seq_stop,
 	.show		= bpf_iter_udp_seq_show,
 };
+
+static unsigned short seq_file_family(const struct seq_file *seq)
+{
+	const struct udp_seq_afinfo *afinfo;
+
+	/* BPF iterator: bpf programs to filter sockets. */
+	if (seq->op == &bpf_iter_udp_seq_ops)
+		return AF_UNSPEC;
+
+	/* Proc fs iterator */
+	afinfo = pde_data(file_inode(seq->file));
+	return afinfo->family;
+}
 #endif
 
 const struct seq_operations udp_seq_ops = {
@@ -3413,9 +3614,30 @@  static struct pernet_operations __net_initdata udp_sysctl_ops = {
 DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta,
 		     struct udp_sock *udp_sk, uid_t uid, int bucket)
 
+static int bpf_iter_udp_realloc_batch(struct bpf_udp_iter_state *iter,
+				      unsigned int new_batch_sz)
+{
+	struct sock **new_batch;
+
+	new_batch = kvmalloc_array(new_batch_sz, sizeof(*new_batch),
+				   GFP_USER | __GFP_NOWARN);
+	if (!new_batch)
+		return -ENOMEM;
+
+	bpf_iter_udp_unref_batch(iter);
+	kvfree(iter->batch);
+	iter->batch = new_batch;
+	iter->max_sk = new_batch_sz;
+
+	return 0;
+}
+
+#define INIT_BATCH_SZ 16
+
 static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
 {
-	struct udp_iter_state *st = priv_data;
+	struct bpf_udp_iter_state *iter = priv_data;
+	struct udp_iter_state *st = &iter->state;
 	struct udp_seq_afinfo *afinfo;
 	int ret;
 
@@ -3427,24 +3649,39 @@  static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux)
 	afinfo->udp_table = NULL;
 	st->bpf_seq_afinfo = afinfo;
 	ret = bpf_iter_init_seq_net(priv_data, aux);
-	if (ret)
+	if (ret) {
 		kfree(afinfo);
+		return ret;
+	}
+	ret = bpf_iter_udp_realloc_batch(iter, INIT_BATCH_SZ);
+	if (ret) {
+		bpf_iter_fini_seq_net(priv_data);
+		return ret;
+	}
+	iter->cur_sk = 0;
+	iter->end_sk = 0;
+	iter->st_bucket_done = false;
+	st->bucket = 0;
+	st->offset = 0;
+
 	return ret;
 }
 
 static void bpf_iter_fini_udp(void *priv_data)
 {
-	struct udp_iter_state *st = priv_data;
+	struct bpf_udp_iter_state *iter = priv_data;
+	struct udp_iter_state *st = &iter->state;
 
-	kfree(st->bpf_seq_afinfo);
 	bpf_iter_fini_seq_net(priv_data);
+	kfree(st->bpf_seq_afinfo);
+	kvfree(iter->batch);
 }
 
 static const struct bpf_iter_seq_info udp_seq_info = {
 	.seq_ops		= &bpf_iter_udp_seq_ops,
 	.init_seq_private	= bpf_iter_init_udp,
 	.fini_seq_private	= bpf_iter_fini_udp,
-	.seq_priv_size		= sizeof(struct udp_iter_state),
+	.seq_priv_size		= sizeof(struct bpf_udp_iter_state),
 };
 
 static struct bpf_iter_reg udp_reg_info = {