diff mbox series

[RFC,net-next,v2,2/5] net: psample: add multicast filtering on group_id

Message ID 20240408125753.470419-3-amorenoz@redhat.com (mailing list archive)
State RFC
Delegated to: Netdev Maintainers
Headers show
Series net: openvswitch: Add sample multicasting. | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for net-next
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 951 this patch: 951
netdev/build_tools success Errors and warnings before: 0 this patch: 0
netdev/cc_maintainers warning 3 maintainers not CCed: pabeni@redhat.com kuba@kernel.org edumazet@google.com
netdev/build_clang success Errors and warnings before: 954 this patch: 954
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 962 this patch: 962
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 171 lines checked
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline fail Was 0 now: 1

Commit Message

Adrián Moreno April 8, 2024, 12:57 p.m. UTC
Packet samples can come from several places (e.g: different tc sample
actions), typically using the sample group (PSAMPLE_ATTR_SAMPLE_GROUP)
to differentiate them.

Likewise, sample consumers that listen on the multicast group may only
be interested in a single group. However, they are currently forced to
receive all samples and discard the ones that are not relevant, causing
unnecessary overhead.

Allow users to filter on the desired group_id by adding a new command
SAMPLE_FILTER_SET that can be used to pass the desired group id.
Store this filter on the per-socket private pointer and use it for
filtering multicasted samples.

Signed-off-by: Adrian Moreno <amorenoz@redhat.com>
---
 include/uapi/linux/psample.h |   1 +
 net/psample/psample.c        | 127 +++++++++++++++++++++++++++++++++--
 2 files changed, 122 insertions(+), 6 deletions(-)

Comments

Ilya Maximets April 8, 2024, 1:18 p.m. UTC | #1
[copying my previous reply since this version actually has netdev@ in Cc]

On 4/8/24 14:57, Adrian Moreno wrote:
> Packet samples can come from several places (e.g: different tc sample
> actions), typically using the sample group (PSAMPLE_ATTR_SAMPLE_GROUP)
> to differentiate them.
> 
> Likewise, sample consumers that listen on the multicast group may only
> be interested on a single group. However, they are currently forced to
> receive all samples and discard the ones that are not relevant, causing
> unnecessary overhead.
> 
> Allow users to filter on the desired group_id by adding a new command
> SAMPLE_FILTER_SET that can be used to pass the desired group id.
> Store this filter on the per-socket private pointer and use it for
> filtering multicasted samples.
> 
> Signed-off-by: Adrian Moreno <amorenoz@redhat.com>
> ---
>  include/uapi/linux/psample.h |   1 +
>  net/psample/psample.c        | 127 +++++++++++++++++++++++++++++++++--
>  2 files changed, 122 insertions(+), 6 deletions(-)
> 
> diff --git a/include/uapi/linux/psample.h b/include/uapi/linux/psample.h
> index e585db5bf2d2..5e0305b1520d 100644
> --- a/include/uapi/linux/psample.h
> +++ b/include/uapi/linux/psample.h
> @@ -28,6 +28,7 @@ enum psample_command {
>  	PSAMPLE_CMD_GET_GROUP,
>  	PSAMPLE_CMD_NEW_GROUP,
>  	PSAMPLE_CMD_DEL_GROUP,
> +	PSAMPLE_CMD_SAMPLE_FILTER_SET,
Other commands are named as PSAMPLE_CMD_VERB_NOUN, so this new one
should be PSAMPLE_CMD_SET_FILTER.  (The SAMPLE part seems unnecessary.)
Some functions/structures need to be renamed accordingly.

>  };
>  
>  enum psample_tunnel_key_attr {
> diff --git a/net/psample/psample.c b/net/psample/psample.c
> index a5d9b8446f77..a0cef63dfdec 100644
> --- a/net/psample/psample.c
> +++ b/net/psample/psample.c
> @@ -98,13 +98,84 @@ static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
>  	return msg->len;
>  }
>  
> -static const struct genl_small_ops psample_nl_ops[] = {
> +struct psample_obj_desc {
> +	struct rcu_head rcu;
> +	u32 group_num;
> +	bool group_num_valid;
> +};
> +
> +struct psample_nl_sock_priv {
> +	struct psample_obj_desc __rcu *flt;

Can we call it 'filter'?  I find it hard to read the code with
this unnecessary abbreviation.  Same for the lock below.

> +	spinlock_t flt_lock; /* Protects flt. */
> +};
> +
> +static void psample_nl_sock_priv_init(void *priv)
> +{
> +	struct psample_nl_sock_priv *sk_priv = priv;
> +
> +	spin_lock_init(&sk_priv->flt_lock);
> +}
> +
> +static void psample_nl_sock_priv_destroy(void *priv)
> +{
> +	struct psample_nl_sock_priv *sk_priv = priv;
> +	struct psample_obj_desc *flt;
> +
> +	flt = rcu_dereference_protected(sk_priv->flt, true);
> +	kfree_rcu(flt, rcu);
> +}
> +
> +static int psample_nl_sample_filter_set_doit(struct sk_buff *skb,
> +					     struct genl_info *info)
> +{
> +	struct psample_nl_sock_priv *sk_priv;
> +	struct nlattr **attrs = info->attrs;
> +	struct psample_obj_desc *flt;
> +
> +	flt = kzalloc(sizeof(*flt), GFP_KERNEL);
> +
> +	if (attrs[PSAMPLE_ATTR_SAMPLE_GROUP]) {
> +		flt->group_num = nla_get_u32(attrs[PSAMPLE_ATTR_SAMPLE_GROUP]);
> +		flt->group_num_valid = true;
> +	}
> +
> +	if (!flt->group_num_valid) {
> +		kfree(flt);

Might be better to not allocate it in the first place.

> +		flt = NULL;
> +	}
> +
> +	sk_priv = genl_sk_priv_get(&psample_nl_family, NETLINK_CB(skb).sk);
> +	if (IS_ERR(sk_priv)) {
> +		kfree(flt);
> +		return PTR_ERR(sk_priv);
> +	}
> +
> +	spin_lock(&sk_priv->flt_lock);
> +	flt = rcu_replace_pointer(sk_priv->flt, flt,
> +				  lockdep_is_held(&sk_priv->flt_lock));
> +	spin_unlock(&sk_priv->flt_lock);
> +	kfree_rcu(flt, rcu);
> +	return 0;
> +}
> +
> +static const struct nla_policy
> +	psample_sample_filter_set_policy[PSAMPLE_ATTR_SAMPLE_GROUP + 1] = {
> +	[PSAMPLE_ATTR_SAMPLE_GROUP] = { .type = NLA_U32, },

This indentation is confusing, though I'm not sure what's a better way.

> +};
> +
> +static const struct genl_ops psample_nl_ops[] = {
>  	{
>  		.cmd = PSAMPLE_CMD_GET_GROUP,
>  		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
>  		.dumpit = psample_nl_cmd_get_group_dumpit,
>  		/* can be retrieved by unprivileged users */
> -	}
> +	},
> +	{
> +		.cmd		= PSAMPLE_CMD_SAMPLE_FILTER_SET,
> +		.doit		= psample_nl_sample_filter_set_doit,
> +		.policy		= psample_sample_filter_set_policy,
> +		.flags		= 0,
> +	},
>  };
>  
>  static struct genl_family psample_nl_family __ro_after_init = {
> @@ -114,10 +185,13 @@ static struct genl_family psample_nl_family __ro_after_init = {
>  	.netnsok	= true,
>  	.module		= THIS_MODULE,
>  	.mcgrps		= psample_nl_mcgrps,
> -	.small_ops	= psample_nl_ops,
> -	.n_small_ops	= ARRAY_SIZE(psample_nl_ops),
> +	.ops		= psample_nl_ops,
> +	.n_ops		= ARRAY_SIZE(psample_nl_ops),
>  	.resv_start_op	= PSAMPLE_CMD_GET_GROUP + 1,
>  	.n_mcgrps	= ARRAY_SIZE(psample_nl_mcgrps),
> +	.sock_priv_size		= sizeof(struct psample_nl_sock_priv),
> +	.sock_priv_init		= psample_nl_sock_priv_init,
> +	.sock_priv_destroy	= psample_nl_sock_priv_destroy,
>  };
>  
>  static void psample_group_notify(struct psample_group *group,
> @@ -360,6 +434,42 @@ static int psample_tunnel_meta_len(struct ip_tunnel_info *tun_info)
>  }
>  #endif
>  
> +static inline void psample_nl_obj_desc_init(struct psample_obj_desc *desc,
> +					    u32 group_num)
> +{
> +	memset(desc, 0, sizeof(*desc));
> +	desc->group_num = group_num;
> +	desc->group_num_valid = true;
> +}
> +
> +static bool psample_obj_desc_match(struct psample_obj_desc *desc,
> +				   struct psample_obj_desc *flt)
> +{
> +	if (desc->group_num_valid && flt->group_num_valid &&
> +	    desc->group_num != flt->group_num)
> +		return false;
> +	return true;

This function returns 'true' if one of the arguments is not valid.
I'd not expect such behavior from a 'match' function.

I understand the intention that psample should sample everything
to sockets that do not request filters, but that should not be part
of the 'match' logic, or more appropriate function name should be
chosen.  Also, if the group is not initialized, but the filter is,
it should not match, logically.  The validity of the filter and of the
current sample is not symmetric.

And I'm not really sure if the 'group_num_valid' is actually needed.
Can the NULL pointer be used as an indicator?  If so, then maybe
the whole psample_obj_desc structure is not needed as it will
contain a single field.

> +}
> +
> +static int psample_nl_sample_filter(struct sock *dsk, struct sk_buff *skb,
> +				    void *data)
> +{
> +	struct psample_obj_desc *desc = data;
> +	struct psample_nl_sock_priv *sk_priv;
> +	struct psample_obj_desc *flt;
> +	int ret = 0;
> +
> +	rcu_read_lock();
> +	sk_priv = __genl_sk_priv_get(&psample_nl_family, dsk);
> +	if (!IS_ERR_OR_NULL(sk_priv)) {
> +		flt = rcu_dereference(sk_priv->flt);
> +		if (flt)
> +			ret = !psample_obj_desc_match(desc, flt);
> +	}
> +	rcu_read_unlock();
> +	return ret;
> +}
> +
>  void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>  			   u32 sample_rate, const struct psample_metadata *md)
>  {
> @@ -370,6 +480,7 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>  #ifdef CONFIG_INET
>  	struct ip_tunnel_info *tun_info;
>  #endif
> +	struct psample_obj_desc desc;
>  	struct sk_buff *nl_skb;
>  	int data_len;
>  	int meta_len;
> @@ -487,8 +598,12 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>  #endif
>  
>  	genlmsg_end(nl_skb, data);
> -	genlmsg_multicast_netns(&psample_nl_family, group->net, nl_skb, 0,
> -				PSAMPLE_NL_MCGRP_SAMPLE, GFP_ATOMIC);
> +	psample_nl_obj_desc_init(&desc, group->group_num);
> +	genlmsg_multicast_netns_filtered(&psample_nl_family,
> +					 group->net, nl_skb, 0,
> +					 PSAMPLE_NL_MCGRP_SAMPLE,
> +					 GFP_ATOMIC, psample_nl_sample_filter,
> +					 &desc);
>  
>  	return;
>  error:
Adrián Moreno April 8, 2024, 7:24 p.m. UTC | #2
On 4/8/24 15:18, Ilya Maximets wrote:
> [copying my previous reply since this version actually has netdev@ in Cc]
> 
> On 4/8/24 14:57, Adrian Moreno wrote:
>> Packet samples can come from several places (e.g: different tc sample
>> actions), typically using the sample group (PSAMPLE_ATTR_SAMPLE_GROUP)
>> to differentiate them.
>>
>> Likewise, sample consumers that listen on the multicast group may only
>> be interested on a single group. However, they are currently forced to
>> receive all samples and discard the ones that are not relevant, causing
>> unnecessary overhead.
>>
>> Allow users to filter on the desired group_id by adding a new command
>> SAMPLE_FILTER_SET that can be used to pass the desired group id.
>> Store this filter on the per-socket private pointer and use it for
>> filtering multicasted samples.
>>
>> Signed-off-by: Adrian Moreno <amorenoz@redhat.com>
>> ---
>>   include/uapi/linux/psample.h |   1 +
>>   net/psample/psample.c        | 127 +++++++++++++++++++++++++++++++++--
>>   2 files changed, 122 insertions(+), 6 deletions(-)
>>
>> diff --git a/include/uapi/linux/psample.h b/include/uapi/linux/psample.h
>> index e585db5bf2d2..5e0305b1520d 100644
>> --- a/include/uapi/linux/psample.h
>> +++ b/include/uapi/linux/psample.h
>> @@ -28,6 +28,7 @@ enum psample_command {
>>   	PSAMPLE_CMD_GET_GROUP,
>>   	PSAMPLE_CMD_NEW_GROUP,
>>   	PSAMPLE_CMD_DEL_GROUP,
>> +	PSAMPLE_CMD_SAMPLE_FILTER_SET,
> Other commands are names as PSAMPLE_CMD_VERB_NOUN, so this new one
> should be PSAMPLE_CMD_SET_FILTER.  (The SAMPLE part seems unnecessary.)
> Some functions/structures need to be renamed accordingly.
> 

Sure, I'll rename it when I send the next version.

>>   };
>>   
>>   enum psample_tunnel_key_attr {
>> diff --git a/net/psample/psample.c b/net/psample/psample.c
>> index a5d9b8446f77..a0cef63dfdec 100644
>> --- a/net/psample/psample.c
>> +++ b/net/psample/psample.c
>> @@ -98,13 +98,84 @@ static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
>>   	return msg->len;
>>   }
>>   
>> -static const struct genl_small_ops psample_nl_ops[] = {
>> +struct psample_obj_desc {
>> +	struct rcu_head rcu;
>> +	u32 group_num;
>> +	bool group_num_valid;
>> +};
>> +
>> +struct psample_nl_sock_priv {
>> +	struct psample_obj_desc __rcu *flt;
> 
> Can we call it 'fileter' ?  I find it hard to read the code with
> this unnecessary abbreviation.  Same for the lock below.
> 

Sure.

>> +	spinlock_t flt_lock; /* Protects flt. */
>> +};
>> +
>> +static void psample_nl_sock_priv_init(void *priv)
>> +{
>> +	struct psample_nl_sock_priv *sk_priv = priv;
>> +
>> +	spin_lock_init(&sk_priv->flt_lock);
>> +}
>> +
>> +static void psample_nl_sock_priv_destroy(void *priv)
>> +{
>> +	struct psample_nl_sock_priv *sk_priv = priv;
>> +	struct psample_obj_desc *flt;
>> +
>> +	flt = rcu_dereference_protected(sk_priv->flt, true);
>> +	kfree_rcu(flt, rcu);
>> +}
>> +
>> +static int psample_nl_sample_filter_set_doit(struct sk_buff *skb,
>> +					     struct genl_info *info)
>> +{
>> +	struct psample_nl_sock_priv *sk_priv;
>> +	struct nlattr **attrs = info->attrs;
>> +	struct psample_obj_desc *flt;
>> +
>> +	flt = kzalloc(sizeof(*flt), GFP_KERNEL);
>> +
>> +	if (attrs[PSAMPLE_ATTR_SAMPLE_GROUP]) {
>> +		flt->group_num = nla_get_u32(attrs[PSAMPLE_ATTR_SAMPLE_GROUP]);
>> +		flt->group_num_valid = true;
>> +	}
>> +
>> +	if (!flt->group_num_valid) {
>> +		kfree(flt);
> 
> Might be better to not allocate it in the first place.
> 

Absolutely.

>> +		flt = NULL;
>> +	}
>> +
>> +	sk_priv = genl_sk_priv_get(&psample_nl_family, NETLINK_CB(skb).sk);
>> +	if (IS_ERR(sk_priv)) {
>> +		kfree(flt);
>> +		return PTR_ERR(sk_priv);
>> +	}
>> +
>> +	spin_lock(&sk_priv->flt_lock);
>> +	flt = rcu_replace_pointer(sk_priv->flt, flt,
>> +				  lockdep_is_held(&sk_priv->flt_lock));
>> +	spin_unlock(&sk_priv->flt_lock);
>> +	kfree_rcu(flt, rcu);
>> +	return 0;
>> +}
>> +
>> +static const struct nla_policy
>> +	psample_sample_filter_set_policy[PSAMPLE_ATTR_SAMPLE_GROUP + 1] = {
>> +	[PSAMPLE_ATTR_SAMPLE_GROUP] = { .type = NLA_U32, },
> 
> This indentation is confusing, though I'm not sure what's a better way.
> 

I know! I'll try to move it around to see if it improves things.

>> +};
>> +
>> +static const struct genl_ops psample_nl_ops[] = {
>>   	{
>>   		.cmd = PSAMPLE_CMD_GET_GROUP,
>>   		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
>>   		.dumpit = psample_nl_cmd_get_group_dumpit,
>>   		/* can be retrieved by unprivileged users */
>> -	}
>> +	},
>> +	{
>> +		.cmd		= PSAMPLE_CMD_SAMPLE_FILTER_SET,
>> +		.doit		= psample_nl_sample_filter_set_doit,
>> +		.policy		= psample_sample_filter_set_policy,
>> +		.flags		= 0,
>> +	},
>>   };
>>   
>>   static struct genl_family psample_nl_family __ro_after_init = {
>> @@ -114,10 +185,13 @@ static struct genl_family psample_nl_family __ro_after_init = {
>>   	.netnsok	= true,
>>   	.module		= THIS_MODULE,
>>   	.mcgrps		= psample_nl_mcgrps,
>> -	.small_ops	= psample_nl_ops,
>> -	.n_small_ops	= ARRAY_SIZE(psample_nl_ops),
>> +	.ops		= psample_nl_ops,
>> +	.n_ops		= ARRAY_SIZE(psample_nl_ops),
>>   	.resv_start_op	= PSAMPLE_CMD_GET_GROUP + 1,
>>   	.n_mcgrps	= ARRAY_SIZE(psample_nl_mcgrps),
>> +	.sock_priv_size		= sizeof(struct psample_nl_sock_priv),
>> +	.sock_priv_init		= psample_nl_sock_priv_init,
>> +	.sock_priv_destroy	= psample_nl_sock_priv_destroy,
>>   };
>>   
>>   static void psample_group_notify(struct psample_group *group,
>> @@ -360,6 +434,42 @@ static int psample_tunnel_meta_len(struct ip_tunnel_info *tun_info)
>>   }
>>   #endif
>>   
>> +static inline void psample_nl_obj_desc_init(struct psample_obj_desc *desc,
>> +					    u32 group_num)
>> +{
>> +	memset(desc, 0, sizeof(*desc));
>> +	desc->group_num = group_num;
>> +	desc->group_num_valid = true;
>> +}
>> +
>> +static bool psample_obj_desc_match(struct psample_obj_desc *desc,
>> +				   struct psample_obj_desc *flt)
>> +{
>> +	if (desc->group_num_valid && flt->group_num_valid &&
>> +	    desc->group_num != flt->group_num)
>> +		return false;
>> +	return true;
> 
> This fucntion returns 'true' if one of the arguments is not valid.
> I'd not expect such behavior from a 'match' function.
> 
> I understand the intention that psample should sample everything
> to sockets that do not request filters, but that should not be part
> of the 'match' logic, or more appropriate function name should be
> chosen.  Also, if the group is not initialized, but the filter is,
> it should not match, logically.  The validity on filter and the
> current sample is not symmetric.
> 

The descriptor should always be initialized but I think double checking should 
be OK as in the context of this particular function, it might not be clear it is.

> And I'm not really sure if the 'group_num_valid' is actually needed.
> Can the NULL pointer be used as an indicator?  If so, then maybe
> the whole psample_obj_desc structure is not needed as it will
> contain a single field.

If we only filter on group_id, then yes. However, as I was writing this, I
thought about opening the door to filtering on more fields, such as the protocol,
in/out interfaces, etc. Now that I read this, I understand the current code is
confusing: I should have left a comment or mentioned it in the commit message.

> 
>> +}
>> +
>> +static int psample_nl_sample_filter(struct sock *dsk, struct sk_buff *skb,
>> +				    void *data)
>> +{
>> +	struct psample_obj_desc *desc = data;
>> +	struct psample_nl_sock_priv *sk_priv;
>> +	struct psample_obj_desc *flt;
>> +	int ret = 0;
>> +
>> +	rcu_read_lock();
>> +	sk_priv = __genl_sk_priv_get(&psample_nl_family, dsk);
>> +	if (!IS_ERR_OR_NULL(sk_priv)) {
>> +		flt = rcu_dereference(sk_priv->flt);
>> +		if (flt)
>> +			ret = !psample_obj_desc_match(desc, flt);
>> +	}
>> +	rcu_read_unlock();
>> +	return ret;
>> +}
>> +
>>   void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>   			   u32 sample_rate, const struct psample_metadata *md)
>>   {
>> @@ -370,6 +480,7 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>   #ifdef CONFIG_INET
>>   	struct ip_tunnel_info *tun_info;
>>   #endif
>> +	struct psample_obj_desc desc;
>>   	struct sk_buff *nl_skb;
>>   	int data_len;
>>   	int meta_len;
>> @@ -487,8 +598,12 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>   #endif
>>   
>>   	genlmsg_end(nl_skb, data);
>> -	genlmsg_multicast_netns(&psample_nl_family, group->net, nl_skb, 0,
>> -				PSAMPLE_NL_MCGRP_SAMPLE, GFP_ATOMIC);
>> +	psample_nl_obj_desc_init(&desc, group->group_num);
>> +	genlmsg_multicast_netns_filtered(&psample_nl_family,
>> +					 group->net, nl_skb, 0,
>> +					 PSAMPLE_NL_MCGRP_SAMPLE,
>> +					 GFP_ATOMIC, psample_nl_sample_filter,
>> +					 &desc);
>>   
>>   	return;
>>   error:
>
Aaron Conole April 9, 2024, 2:43 p.m. UTC | #3
Adrian Moreno <amorenoz@redhat.com> writes:

> On 4/8/24 15:18, Ilya Maximets wrote:
>> [copying my previous reply since this version actually has netdev@ in Cc]
>> On 4/8/24 14:57, Adrian Moreno wrote:
>>> Packet samples can come from several places (e.g: different tc sample
>>> actions), typically using the sample group (PSAMPLE_ATTR_SAMPLE_GROUP)
>>> to differentiate them.
>>>
>>> Likewise, sample consumers that listen on the multicast group may only
>>> be interested on a single group. However, they are currently forced to
>>> receive all samples and discard the ones that are not relevant, causing
>>> unnecessary overhead.
>>>
>>> Allow users to filter on the desired group_id by adding a new command
>>> SAMPLE_FILTER_SET that can be used to pass the desired group id.
>>> Store this filter on the per-socket private pointer and use it for
>>> filtering multicasted samples.
>>>
>>> Signed-off-by: Adrian Moreno <amorenoz@redhat.com>
>>> ---
>>>   include/uapi/linux/psample.h |   1 +
>>>   net/psample/psample.c        | 127 +++++++++++++++++++++++++++++++++--
>>>   2 files changed, 122 insertions(+), 6 deletions(-)
>>>
>>> diff --git a/include/uapi/linux/psample.h b/include/uapi/linux/psample.h
>>> index e585db5bf2d2..5e0305b1520d 100644
>>> --- a/include/uapi/linux/psample.h
>>> +++ b/include/uapi/linux/psample.h
>>> @@ -28,6 +28,7 @@ enum psample_command {
>>>   	PSAMPLE_CMD_GET_GROUP,
>>>   	PSAMPLE_CMD_NEW_GROUP,
>>>   	PSAMPLE_CMD_DEL_GROUP,
>>> +	PSAMPLE_CMD_SAMPLE_FILTER_SET,
>> Other commands are names as PSAMPLE_CMD_VERB_NOUN, so this new one
>> should be PSAMPLE_CMD_SET_FILTER.  (The SAMPLE part seems unnecessary.)
>> Some functions/structures need to be renamed accordingly.
>> 
>
> Sure, I'll rename it when I sent the next version.
>
>>>   };
>>>     enum psample_tunnel_key_attr {
>>> diff --git a/net/psample/psample.c b/net/psample/psample.c
>>> index a5d9b8446f77..a0cef63dfdec 100644
>>> --- a/net/psample/psample.c
>>> +++ b/net/psample/psample.c
>>> @@ -98,13 +98,84 @@ static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
>>>   	return msg->len;
>>>   }
>>>   -static const struct genl_small_ops psample_nl_ops[] = {
>>> +struct psample_obj_desc {
>>> +	struct rcu_head rcu;
>>> +	u32 group_num;
>>> +	bool group_num_valid;
>>> +};
>>> +
>>> +struct psample_nl_sock_priv {
>>> +	struct psample_obj_desc __rcu *flt;
>> Can we call it 'fileter' ?  I find it hard to read the code with
>> this unnecessary abbreviation.  Same for the lock below.
>> 
>
> Sure.
>
>>> +	spinlock_t flt_lock; /* Protects flt. */
>>> +};
>>> +
>>> +static void psample_nl_sock_priv_init(void *priv)
>>> +{
>>> +	struct psample_nl_sock_priv *sk_priv = priv;
>>> +
>>> +	spin_lock_init(&sk_priv->flt_lock);
>>> +}
>>> +
>>> +static void psample_nl_sock_priv_destroy(void *priv)
>>> +{
>>> +	struct psample_nl_sock_priv *sk_priv = priv;
>>> +	struct psample_obj_desc *flt;
>>> +
>>> +	flt = rcu_dereference_protected(sk_priv->flt, true);
>>> +	kfree_rcu(flt, rcu);
>>> +}
>>> +
>>> +static int psample_nl_sample_filter_set_doit(struct sk_buff *skb,
>>> +					     struct genl_info *info)
>>> +{
>>> +	struct psample_nl_sock_priv *sk_priv;
>>> +	struct nlattr **attrs = info->attrs;
>>> +	struct psample_obj_desc *flt;
>>> +
>>> +	flt = kzalloc(sizeof(*flt), GFP_KERNEL);
>>> +
>>> +	if (attrs[PSAMPLE_ATTR_SAMPLE_GROUP]) {
>>> +		flt->group_num = nla_get_u32(attrs[PSAMPLE_ATTR_SAMPLE_GROUP]);
>>> +		flt->group_num_valid = true;
>>> +	}
>>> +
>>> +	if (!flt->group_num_valid) {
>>> +		kfree(flt);
>> Might be better to not allocate it in the first place.
>> 
>
> Absolutely.
>
>>> +		flt = NULL;
>>> +	}
>>> +
>>> +	sk_priv = genl_sk_priv_get(&psample_nl_family, NETLINK_CB(skb).sk);
>>> +	if (IS_ERR(sk_priv)) {
>>> +		kfree(flt);
>>> +		return PTR_ERR(sk_priv);
>>> +	}
>>> +
>>> +	spin_lock(&sk_priv->flt_lock);
>>> +	flt = rcu_replace_pointer(sk_priv->flt, flt,
>>> +				  lockdep_is_held(&sk_priv->flt_lock));
>>> +	spin_unlock(&sk_priv->flt_lock);
>>> +	kfree_rcu(flt, rcu);
>>> +	return 0;
>>> +}
>>> +
>>> +static const struct nla_policy
>>> +	psample_sample_filter_set_policy[PSAMPLE_ATTR_SAMPLE_GROUP + 1] = {
>>> +	[PSAMPLE_ATTR_SAMPLE_GROUP] = { .type = NLA_U32, },
>> This indentation is confusing, though I'm not sure what's a better
>> way.
>> 
>
> I now! I'll try to move it around see if it improves things.
>
>>> +};
>>> +
>>> +static const struct genl_ops psample_nl_ops[] = {
>>>   	{
>>>   		.cmd = PSAMPLE_CMD_GET_GROUP,
>>>   		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
>>>   		.dumpit = psample_nl_cmd_get_group_dumpit,
>>>   		/* can be retrieved by unprivileged users */
>>> -	}
>>> +	},
>>> +	{
>>> +		.cmd		= PSAMPLE_CMD_SAMPLE_FILTER_SET,
>>> +		.doit		= psample_nl_sample_filter_set_doit,
>>> +		.policy		= psample_sample_filter_set_policy,
>>> +		.flags		= 0,
>>> +	},
>>>   };
>>>     static struct genl_family psample_nl_family __ro_after_init = {
>>> @@ -114,10 +185,13 @@ static struct genl_family psample_nl_family __ro_after_init = {
>>>   	.netnsok	= true,
>>>   	.module		= THIS_MODULE,
>>>   	.mcgrps		= psample_nl_mcgrps,
>>> -	.small_ops	= psample_nl_ops,
>>> -	.n_small_ops	= ARRAY_SIZE(psample_nl_ops),
>>> +	.ops		= psample_nl_ops,
>>> +	.n_ops		= ARRAY_SIZE(psample_nl_ops),
>>>   	.resv_start_op	= PSAMPLE_CMD_GET_GROUP + 1,
>>>   	.n_mcgrps	= ARRAY_SIZE(psample_nl_mcgrps),
>>> +	.sock_priv_size		= sizeof(struct psample_nl_sock_priv),
>>> +	.sock_priv_init		= psample_nl_sock_priv_init,
>>> +	.sock_priv_destroy	= psample_nl_sock_priv_destroy,
>>>   };
>>>     static void psample_group_notify(struct psample_group *group,
>>> @@ -360,6 +434,42 @@ static int psample_tunnel_meta_len(struct ip_tunnel_info *tun_info)
>>>   }
>>>   #endif
>>>   +static inline void psample_nl_obj_desc_init(struct
>>> psample_obj_desc *desc,
>>> +					    u32 group_num)
>>> +{
>>> +	memset(desc, 0, sizeof(*desc));
>>> +	desc->group_num = group_num;
>>> +	desc->group_num_valid = true;
>>> +}
>>> +
>>> +static bool psample_obj_desc_match(struct psample_obj_desc *desc,
>>> +				   struct psample_obj_desc *flt)
>>> +{
>>> +	if (desc->group_num_valid && flt->group_num_valid &&
>>> +	    desc->group_num != flt->group_num)
>>> +		return false;
>>> +	return true;
>> This fucntion returns 'true' if one of the arguments is not valid.
>> I'd not expect such behavior from a 'match' function.
>> I understand the intention that psample should sample everything
>> to sockets that do not request filters, but that should not be part
>> of the 'match' logic, or more appropriate function name should be
>> chosen.  Also, if the group is not initialized, but the filter is,
>> it should not match, logically.  The validity on filter and the
>> current sample is not symmetric.
>> 
>
> The descriptor should always be initialized but I think double
> checking should be OK as in the context of this particular function,
> it might not be clear it is.
>
>> And I'm not really sure if the 'group_num_valid' is actually needed.
>> Can the NULL pointer be used as an indicator?  If so, then maybe
>> the whole psample_obj_desc structure is not needed as it will
>> contain a single field.
>
> If we only filter on group_id, then yes. However, as I was writing
> this, I thought maybe opening the door to filtering on more fields
> such as the protocol in/out interfaces, etc. Now that I read this I
> understand the current code is confusing: I should have left a comment
> or mention it in the commit message.

If you want to have such filtering options, does it make sense to
have the listening program send a set of BPF instructions for
filtering instead?  I think the data should be available at the point
where simple bpf is attached (SO_ATTACH_BPF to the psample socket, and
the filter should run as part of the broadcast message IIRC since it
populates the sk_filter field).

>> 
>>> +}
>>> +
>>> +static int psample_nl_sample_filter(struct sock *dsk, struct sk_buff *skb,
>>> +				    void *data)
>>> +{
>>> +	struct psample_obj_desc *desc = data;
>>> +	struct psample_nl_sock_priv *sk_priv;
>>> +	struct psample_obj_desc *flt;
>>> +	int ret = 0;
>>> +
>>> +	rcu_read_lock();
>>> +	sk_priv = __genl_sk_priv_get(&psample_nl_family, dsk);
>>> +	if (!IS_ERR_OR_NULL(sk_priv)) {
>>> +		flt = rcu_dereference(sk_priv->flt);
>>> +		if (flt)
>>> +			ret = !psample_obj_desc_match(desc, flt);
>>> +	}
>>> +	rcu_read_unlock();
>>> +	return ret;
>>> +}
>>> +
>>>   void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>>   			   u32 sample_rate, const struct psample_metadata *md)
>>>   {
>>> @@ -370,6 +480,7 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>>   #ifdef CONFIG_INET
>>>   	struct ip_tunnel_info *tun_info;
>>>   #endif
>>> +	struct psample_obj_desc desc;
>>>   	struct sk_buff *nl_skb;
>>>   	int data_len;
>>>   	int meta_len;
>>> @@ -487,8 +598,12 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>>   #endif
>>>     	genlmsg_end(nl_skb, data);
>>> -	genlmsg_multicast_netns(&psample_nl_family, group->net, nl_skb, 0,
>>> -				PSAMPLE_NL_MCGRP_SAMPLE, GFP_ATOMIC);
>>> +	psample_nl_obj_desc_init(&desc, group->group_num);
>>> +	genlmsg_multicast_netns_filtered(&psample_nl_family,
>>> +					 group->net, nl_skb, 0,
>>> +					 PSAMPLE_NL_MCGRP_SAMPLE,
>>> +					 GFP_ATOMIC, psample_nl_sample_filter,
>>> +					 &desc);
>>>     	return;
>>>   error:
>>
Ido Schimmel April 10, 2024, 1:06 p.m. UTC | #4
On Mon, Apr 08, 2024 at 02:57:41PM +0200, Adrian Moreno wrote:
> Packet samples can come from several places (e.g: different tc sample
> actions), typically using the sample group (PSAMPLE_ATTR_SAMPLE_GROUP)
> to differentiate them.
> 
> Likewise, sample consumers that listen on the multicast group may only
> be interested on a single group. However, they are currently forced to
> receive all samples and discard the ones that are not relevant, causing
> unnecessary overhead.
> 
> Allow users to filter on the desired group_id by adding a new command
> SAMPLE_FILTER_SET that can be used to pass the desired group id.
> Store this filter on the per-socket private pointer and use it for
> filtering multicasted samples.

Did you consider using BPF for this type of filtering instead of new
uAPI?

See example here:
https://github.com/Mellanox/libpsample/blob/master/src/psample.c#L290
Adrián Moreno April 10, 2024, 1:32 p.m. UTC | #5
On 4/9/24 16:43, Aaron Conole wrote:
> Adrian Moreno <amorenoz@redhat.com> writes:
> 
>> On 4/8/24 15:18, Ilya Maximets wrote:
>>> [copying my previous reply since this version actually has netdev@ in Cc]
>>> On 4/8/24 14:57, Adrian Moreno wrote:
>>>> Packet samples can come from several places (e.g: different tc sample
>>>> actions), typically using the sample group (PSAMPLE_ATTR_SAMPLE_GROUP)
>>>> to differentiate them.
>>>>
>>>> Likewise, sample consumers that listen on the multicast group may only
>>>> be interested on a single group. However, they are currently forced to
>>>> receive all samples and discard the ones that are not relevant, causing
>>>> unnecessary overhead.
>>>>
>>>> Allow users to filter on the desired group_id by adding a new command
>>>> SAMPLE_FILTER_SET that can be used to pass the desired group id.
>>>> Store this filter on the per-socket private pointer and use it for
>>>> filtering multicasted samples.
>>>>
>>>> Signed-off-by: Adrian Moreno <amorenoz@redhat.com>
>>>> ---
>>>>    include/uapi/linux/psample.h |   1 +
>>>>    net/psample/psample.c        | 127 +++++++++++++++++++++++++++++++++--
>>>>    2 files changed, 122 insertions(+), 6 deletions(-)
>>>>
>>>> diff --git a/include/uapi/linux/psample.h b/include/uapi/linux/psample.h
>>>> index e585db5bf2d2..5e0305b1520d 100644
>>>> --- a/include/uapi/linux/psample.h
>>>> +++ b/include/uapi/linux/psample.h
>>>> @@ -28,6 +28,7 @@ enum psample_command {
>>>>    	PSAMPLE_CMD_GET_GROUP,
>>>>    	PSAMPLE_CMD_NEW_GROUP,
>>>>    	PSAMPLE_CMD_DEL_GROUP,
>>>> +	PSAMPLE_CMD_SAMPLE_FILTER_SET,
>>> Other commands are names as PSAMPLE_CMD_VERB_NOUN, so this new one
>>> should be PSAMPLE_CMD_SET_FILTER.  (The SAMPLE part seems unnecessary.)
>>> Some functions/structures need to be renamed accordingly.
>>>
>>
>> Sure, I'll rename it when I sent the next version.
>>
>>>>    };
>>>>      enum psample_tunnel_key_attr {
>>>> diff --git a/net/psample/psample.c b/net/psample/psample.c
>>>> index a5d9b8446f77..a0cef63dfdec 100644
>>>> --- a/net/psample/psample.c
>>>> +++ b/net/psample/psample.c
>>>> @@ -98,13 +98,84 @@ static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
>>>>    	return msg->len;
>>>>    }
>>>>    -static const struct genl_small_ops psample_nl_ops[] = {
>>>> +struct psample_obj_desc {
>>>> +	struct rcu_head rcu;
>>>> +	u32 group_num;
>>>> +	bool group_num_valid;
>>>> +};
>>>> +
>>>> +struct psample_nl_sock_priv {
>>>> +	struct psample_obj_desc __rcu *flt;
>>> Can we call it 'filter'?  I find it hard to read the code with
>>> this unnecessary abbreviation.  Same for the lock below.
>>>
>>
>> Sure.
>>
>>>> +	spinlock_t flt_lock; /* Protects flt. */
>>>> +};
>>>> +
>>>> +static void psample_nl_sock_priv_init(void *priv)
>>>> +{
>>>> +	struct psample_nl_sock_priv *sk_priv = priv;
>>>> +
>>>> +	spin_lock_init(&sk_priv->flt_lock);
>>>> +}
>>>> +
>>>> +static void psample_nl_sock_priv_destroy(void *priv)
>>>> +{
>>>> +	struct psample_nl_sock_priv *sk_priv = priv;
>>>> +	struct psample_obj_desc *flt;
>>>> +
>>>> +	flt = rcu_dereference_protected(sk_priv->flt, true);
>>>> +	kfree_rcu(flt, rcu);
>>>> +}
>>>> +
>>>> +static int psample_nl_sample_filter_set_doit(struct sk_buff *skb,
>>>> +					     struct genl_info *info)
>>>> +{
>>>> +	struct psample_nl_sock_priv *sk_priv;
>>>> +	struct nlattr **attrs = info->attrs;
>>>> +	struct psample_obj_desc *flt;
>>>> +
>>>> +	flt = kzalloc(sizeof(*flt), GFP_KERNEL);
>>>> +
>>>> +	if (attrs[PSAMPLE_ATTR_SAMPLE_GROUP]) {
>>>> +		flt->group_num = nla_get_u32(attrs[PSAMPLE_ATTR_SAMPLE_GROUP]);
>>>> +		flt->group_num_valid = true;
>>>> +	}
>>>> +
>>>> +	if (!flt->group_num_valid) {
>>>> +		kfree(flt);
>>> Might be better to not allocate it in the first place.
>>>
>>
>> Absolutely.
>>
>>>> +		flt = NULL;
>>>> +	}
>>>> +
>>>> +	sk_priv = genl_sk_priv_get(&psample_nl_family, NETLINK_CB(skb).sk);
>>>> +	if (IS_ERR(sk_priv)) {
>>>> +		kfree(flt);
>>>> +		return PTR_ERR(sk_priv);
>>>> +	}
>>>> +
>>>> +	spin_lock(&sk_priv->flt_lock);
>>>> +	flt = rcu_replace_pointer(sk_priv->flt, flt,
>>>> +				  lockdep_is_held(&sk_priv->flt_lock));
>>>> +	spin_unlock(&sk_priv->flt_lock);
>>>> +	kfree_rcu(flt, rcu);
>>>> +	return 0;
>>>> +}
>>>> +
>>>> +static const struct nla_policy
>>>> +	psample_sample_filter_set_policy[PSAMPLE_ATTR_SAMPLE_GROUP + 1] = {
>>>> +	[PSAMPLE_ATTR_SAMPLE_GROUP] = { .type = NLA_U32, },
>>> This indentation is confusing, though I'm not sure what's a better
>>> way.
>>>
>>
>> I know! I'll try to move it around and see if it improves things.
>>
>>>> +};
>>>> +
>>>> +static const struct genl_ops psample_nl_ops[] = {
>>>>    	{
>>>>    		.cmd = PSAMPLE_CMD_GET_GROUP,
>>>>    		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
>>>>    		.dumpit = psample_nl_cmd_get_group_dumpit,
>>>>    		/* can be retrieved by unprivileged users */
>>>> -	}
>>>> +	},
>>>> +	{
>>>> +		.cmd		= PSAMPLE_CMD_SAMPLE_FILTER_SET,
>>>> +		.doit		= psample_nl_sample_filter_set_doit,
>>>> +		.policy		= psample_sample_filter_set_policy,
>>>> +		.flags		= 0,
>>>> +	},
>>>>    };
>>>>      static struct genl_family psample_nl_family __ro_after_init = {
>>>> @@ -114,10 +185,13 @@ static struct genl_family psample_nl_family __ro_after_init = {
>>>>    	.netnsok	= true,
>>>>    	.module		= THIS_MODULE,
>>>>    	.mcgrps		= psample_nl_mcgrps,
>>>> -	.small_ops	= psample_nl_ops,
>>>> -	.n_small_ops	= ARRAY_SIZE(psample_nl_ops),
>>>> +	.ops		= psample_nl_ops,
>>>> +	.n_ops		= ARRAY_SIZE(psample_nl_ops),
>>>>    	.resv_start_op	= PSAMPLE_CMD_GET_GROUP + 1,
>>>>    	.n_mcgrps	= ARRAY_SIZE(psample_nl_mcgrps),
>>>> +	.sock_priv_size		= sizeof(struct psample_nl_sock_priv),
>>>> +	.sock_priv_init		= psample_nl_sock_priv_init,
>>>> +	.sock_priv_destroy	= psample_nl_sock_priv_destroy,
>>>>    };
>>>>      static void psample_group_notify(struct psample_group *group,
>>>> @@ -360,6 +434,42 @@ static int psample_tunnel_meta_len(struct ip_tunnel_info *tun_info)
>>>>    }
>>>>    #endif
>>>>    +static inline void psample_nl_obj_desc_init(struct
>>>> psample_obj_desc *desc,
>>>> +					    u32 group_num)
>>>> +{
>>>> +	memset(desc, 0, sizeof(*desc));
>>>> +	desc->group_num = group_num;
>>>> +	desc->group_num_valid = true;
>>>> +}
>>>> +
>>>> +static bool psample_obj_desc_match(struct psample_obj_desc *desc,
>>>> +				   struct psample_obj_desc *flt)
>>>> +{
>>>> +	if (desc->group_num_valid && flt->group_num_valid &&
>>>> +	    desc->group_num != flt->group_num)
>>>> +		return false;
>>>> +	return true;
>>> This function returns 'true' if one of the arguments is not valid.
>>> I'd not expect such behavior from a 'match' function.
>>> I understand the intention that psample should sample everything
>>> to sockets that do not request filters, but that should not be part
>>> of the 'match' logic, or more appropriate function name should be
>>> chosen.  Also, if the group is not initialized, but the filter is,
>>> it should not match, logically.  The validity on filter and the
>>> current sample is not symmetric.
>>>
>>
>> The descriptor should always be initialized but I think double
>> checking should be OK as in the context of this particular function,
>> it might not be clear it is.
>>
>>> And I'm not really sure if the 'group_num_valid' is actually needed.
>>> Can the NULL pointer be used as an indicator?  If so, then maybe
>>> the whole psample_obj_desc structure is not needed as it will
>>> contain a single field.
>>
>> If we only filter on group_id, then yes. However, as I was writing
>> this, I thought about opening the door to filtering on more fields
>> such as the protocol, in/out interfaces, etc. Now that I read this I
>> understand the current code is confusing: I should have left a comment
>> or mentioned it in the commit message.
> 
> If you want to have such filtering options, does it make sense to
> instead have the listening program send a set of bpf instructions for
> filtering instead?  I think the data should be available at the point
> where simple bpf is attached (SO_ATTACH_BPF to the psample socket, and
> the filter should run as part of the broadcast message IIRC since it
> populates the sk_filter field).
> 

That's a good point. I hope parsing the netlink messages won't be too cumbersome.
So let's limit it to group_ids. How about filtering on a number of group_ids? Is 
that worth it?


>>>
>>>> +}
>>>> +
>>>> +static int psample_nl_sample_filter(struct sock *dsk, struct sk_buff *skb,
>>>> +				    void *data)
>>>> +{
>>>> +	struct psample_obj_desc *desc = data;
>>>> +	struct psample_nl_sock_priv *sk_priv;
>>>> +	struct psample_obj_desc *flt;
>>>> +	int ret = 0;
>>>> +
>>>> +	rcu_read_lock();
>>>> +	sk_priv = __genl_sk_priv_get(&psample_nl_family, dsk);
>>>> +	if (!IS_ERR_OR_NULL(sk_priv)) {
>>>> +		flt = rcu_dereference(sk_priv->flt);
>>>> +		if (flt)
>>>> +			ret = !psample_obj_desc_match(desc, flt);
>>>> +	}
>>>> +	rcu_read_unlock();
>>>> +	return ret;
>>>> +}
>>>> +
>>>>    void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>>>    			   u32 sample_rate, const struct psample_metadata *md)
>>>>    {
>>>> @@ -370,6 +480,7 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>>>    #ifdef CONFIG_INET
>>>>    	struct ip_tunnel_info *tun_info;
>>>>    #endif
>>>> +	struct psample_obj_desc desc;
>>>>    	struct sk_buff *nl_skb;
>>>>    	int data_len;
>>>>    	int meta_len;
>>>> @@ -487,8 +598,12 @@ void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
>>>>    #endif
>>>>      	genlmsg_end(nl_skb, data);
>>>> -	genlmsg_multicast_netns(&psample_nl_family, group->net, nl_skb, 0,
>>>> -				PSAMPLE_NL_MCGRP_SAMPLE, GFP_ATOMIC);
>>>> +	psample_nl_obj_desc_init(&desc, group->group_num);
>>>> +	genlmsg_multicast_netns_filtered(&psample_nl_family,
>>>> +					 group->net, nl_skb, 0,
>>>> +					 PSAMPLE_NL_MCGRP_SAMPLE,
>>>> +					 GFP_ATOMIC, psample_nl_sample_filter,
>>>> +					 &desc);
>>>>      	return;
>>>>    error:
>>>
>
Adrián Moreno April 10, 2024, 1:42 p.m. UTC | #6
On 4/10/24 15:06, Ido Schimmel wrote:
> On Mon, Apr 08, 2024 at 02:57:41PM +0200, Adrian Moreno wrote:
>> Packet samples can come from several places (e.g: different tc sample
>> actions), typically using the sample group (PSAMPLE_ATTR_SAMPLE_GROUP)
>> to differentiate them.
>>
>> Likewise, sample consumers that listen on the multicast group may only
>> be interested on a single group. However, they are currently forced to
>> receive all samples and discard the ones that are not relevant, causing
>> unnecessary overhead.
>>
>> Allow users to filter on the desired group_id by adding a new command
>> SAMPLE_FILTER_SET that can be used to pass the desired group id.
>> Store this filter on the per-socket private pointer and use it for
>> filtering multicasted samples.
> 
> Did you consider using BPF for this type of filtering instead of new
> uAPI?
>

Yes. I ended up going for a uAPI change because, since the group_id is part of 
the psample uAPI semantics, requiring users to load ebpf programs for that 
seemed a bit excessive. Given devlink already uses this mechanism [1], I thought 
it would make things easier for users that already just use netlink.

[1] https://lore.kernel.org/netdev/20231214181549.1270696-9-jiri@resnulli.us/

> See example here:
> https://github.com/Mellanox/libpsample/blob/master/src/psample.c#L290
>
diff mbox series

Patch

diff --git a/include/uapi/linux/psample.h b/include/uapi/linux/psample.h
index e585db5bf2d2..5e0305b1520d 100644
--- a/include/uapi/linux/psample.h
+++ b/include/uapi/linux/psample.h
@@ -28,6 +28,7 @@  enum psample_command {
 	PSAMPLE_CMD_GET_GROUP,
 	PSAMPLE_CMD_NEW_GROUP,
 	PSAMPLE_CMD_DEL_GROUP,
+	PSAMPLE_CMD_SAMPLE_FILTER_SET,
 };
 
 enum psample_tunnel_key_attr {
diff --git a/net/psample/psample.c b/net/psample/psample.c
index a5d9b8446f77..a0cef63dfdec 100644
--- a/net/psample/psample.c
+++ b/net/psample/psample.c
@@ -98,13 +98,84 @@  static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
 	return msg->len;
 }
 
-static const struct genl_small_ops psample_nl_ops[] = {
+struct psample_obj_desc {
+	struct rcu_head rcu;
+	u32 group_num;
+	bool group_num_valid;
+};
+
+struct psample_nl_sock_priv {
+	struct psample_obj_desc __rcu *flt;
+	spinlock_t flt_lock; /* Protects flt. */
+};
+
+static void psample_nl_sock_priv_init(void *priv)
+{
+	struct psample_nl_sock_priv *sk_priv = priv;
+
+	spin_lock_init(&sk_priv->flt_lock);
+}
+
+static void psample_nl_sock_priv_destroy(void *priv)
+{
+	struct psample_nl_sock_priv *sk_priv = priv;
+	struct psample_obj_desc *flt;
+
+	flt = rcu_dereference_protected(sk_priv->flt, true);
+	kfree_rcu(flt, rcu);
+}
+
+static int psample_nl_sample_filter_set_doit(struct sk_buff *skb,
+					     struct genl_info *info)
+{
+	struct psample_nl_sock_priv *sk_priv;
+	struct nlattr **attrs = info->attrs;
+	struct psample_obj_desc *flt;
+
+	flt = kzalloc(sizeof(*flt), GFP_KERNEL);
+
+	if (attrs[PSAMPLE_ATTR_SAMPLE_GROUP]) {
+		flt->group_num = nla_get_u32(attrs[PSAMPLE_ATTR_SAMPLE_GROUP]);
+		flt->group_num_valid = true;
+	}
+
+	if (!flt->group_num_valid) {
+		kfree(flt);
+		flt = NULL;
+	}
+
+	sk_priv = genl_sk_priv_get(&psample_nl_family, NETLINK_CB(skb).sk);
+	if (IS_ERR(sk_priv)) {
+		kfree(flt);
+		return PTR_ERR(sk_priv);
+	}
+
+	spin_lock(&sk_priv->flt_lock);
+	flt = rcu_replace_pointer(sk_priv->flt, flt,
+				  lockdep_is_held(&sk_priv->flt_lock));
+	spin_unlock(&sk_priv->flt_lock);
+	kfree_rcu(flt, rcu);
+	return 0;
+}
+
+static const struct nla_policy
+	psample_sample_filter_set_policy[PSAMPLE_ATTR_SAMPLE_GROUP + 1] = {
+	[PSAMPLE_ATTR_SAMPLE_GROUP] = { .type = NLA_U32, },
+};
+
+static const struct genl_ops psample_nl_ops[] = {
 	{
 		.cmd = PSAMPLE_CMD_GET_GROUP,
 		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
 		.dumpit = psample_nl_cmd_get_group_dumpit,
 		/* can be retrieved by unprivileged users */
-	}
+	},
+	{
+		.cmd		= PSAMPLE_CMD_SAMPLE_FILTER_SET,
+		.doit		= psample_nl_sample_filter_set_doit,
+		.policy		= psample_sample_filter_set_policy,
+		.flags		= 0,
+	},
 };
 
 static struct genl_family psample_nl_family __ro_after_init = {
@@ -114,10 +185,13 @@  static struct genl_family psample_nl_family __ro_after_init = {
 	.netnsok	= true,
 	.module		= THIS_MODULE,
 	.mcgrps		= psample_nl_mcgrps,
-	.small_ops	= psample_nl_ops,
-	.n_small_ops	= ARRAY_SIZE(psample_nl_ops),
+	.ops		= psample_nl_ops,
+	.n_ops		= ARRAY_SIZE(psample_nl_ops),
 	.resv_start_op	= PSAMPLE_CMD_GET_GROUP + 1,
 	.n_mcgrps	= ARRAY_SIZE(psample_nl_mcgrps),
+	.sock_priv_size		= sizeof(struct psample_nl_sock_priv),
+	.sock_priv_init		= psample_nl_sock_priv_init,
+	.sock_priv_destroy	= psample_nl_sock_priv_destroy,
 };
 
 static void psample_group_notify(struct psample_group *group,
@@ -360,6 +434,42 @@  static int psample_tunnel_meta_len(struct ip_tunnel_info *tun_info)
 }
 #endif
 
+static inline void psample_nl_obj_desc_init(struct psample_obj_desc *desc,
+					    u32 group_num)
+{
+	memset(desc, 0, sizeof(*desc));
+	desc->group_num = group_num;
+	desc->group_num_valid = true;
+}
+
+static bool psample_obj_desc_match(struct psample_obj_desc *desc,
+				   struct psample_obj_desc *flt)
+{
+	if (desc->group_num_valid && flt->group_num_valid &&
+	    desc->group_num != flt->group_num)
+		return false;
+	return true;
+}
+
+static int psample_nl_sample_filter(struct sock *dsk, struct sk_buff *skb,
+				    void *data)
+{
+	struct psample_obj_desc *desc = data;
+	struct psample_nl_sock_priv *sk_priv;
+	struct psample_obj_desc *flt;
+	int ret = 0;
+
+	rcu_read_lock();
+	sk_priv = __genl_sk_priv_get(&psample_nl_family, dsk);
+	if (!IS_ERR_OR_NULL(sk_priv)) {
+		flt = rcu_dereference(sk_priv->flt);
+		if (flt)
+			ret = !psample_obj_desc_match(desc, flt);
+	}
+	rcu_read_unlock();
+	return ret;
+}
+
 void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
 			   u32 sample_rate, const struct psample_metadata *md)
 {
@@ -370,6 +480,7 @@  void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
 #ifdef CONFIG_INET
 	struct ip_tunnel_info *tun_info;
 #endif
+	struct psample_obj_desc desc;
 	struct sk_buff *nl_skb;
 	int data_len;
 	int meta_len;
@@ -487,8 +598,12 @@  void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
 #endif
 
 	genlmsg_end(nl_skb, data);
-	genlmsg_multicast_netns(&psample_nl_family, group->net, nl_skb, 0,
-				PSAMPLE_NL_MCGRP_SAMPLE, GFP_ATOMIC);
+	psample_nl_obj_desc_init(&desc, group->group_num);
+	genlmsg_multicast_netns_filtered(&psample_nl_family,
+					 group->net, nl_skb, 0,
+					 PSAMPLE_NL_MCGRP_SAMPLE,
+					 GFP_ATOMIC, psample_nl_sample_filter,
+					 &desc);
 
 	return;
 error: