[net] net-timestamp: make sk_tskey more predictable in error path

Message ID 20240210230002.3778461-1-vadfed@meta.com (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers

Checks

Context                         Check    Description
netdev/series_format            success  Single patches do not need cover letters
netdev/tree_selection           success  Clearly marked for net
netdev/ynl                      success  Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present            success  Fixes tag present in non-next series
netdev/header_inline            success  No static functions without inline keyword in header files
netdev/build_32bit              success  Errors and warnings before: 978 this patch: 978
netdev/build_tools              success  No tools touched, skip
netdev/cc_maintainers           warning  3 maintainers not CCed: pabeni@redhat.com edumazet@google.com dsahern@kernel.org
netdev/build_clang              success  Errors and warnings before: 994 this patch: 994
netdev/verify_signedoff         success  Signed-off-by tag matches author and committer
netdev/deprecated_api           success  None detected
netdev/check_selftest           success  No net selftest shell script
netdev/verify_fixes             success  Fixes tag looks correct
netdev/build_allmodconfig_warn  success  Errors and warnings before: 995 this patch: 995
netdev/checkpatch               success  total: 0 errors, 0 warnings, 0 checks, 76 lines checked
netdev/build_clang_rust         success  No Rust files in patch. Skipping build
netdev/kdoc                     success  Errors and warnings before: 0 this patch: 0
netdev/source_inline            success  Was 0 now: 0
netdev/contest                  success  net-next-2024-02-11--15-00 (tests: 1261)

Commit Message

Vadim Fedorenko Feb. 10, 2024, 11 p.m. UTC
When SOF_TIMESTAMPING_OPT_ID is used to disambiguate timestamped datagrams,
sk_tskey can become unpredictable if any error happens during sendmsg().
Move the increment later in the code and decrement sk_tskey in the error
path. This solution is still racy when multiple threads call sendmsg() on
the very same socket in parallel, but it makes the error path much more
predictable.

Fixes: 09c2d251b707 ("net-timestamp: add key to disambiguate concurrent datagrams")
Reported-by: Andy Lutomirski <luto@amacapital.net>
Signed-off-by: Vadim Fedorenko <vadfed@meta.com>
---
 net/ipv4/ip_output.c  | 14 +++++++++-----
 net/ipv6/ip6_output.c | 14 +++++++++-----
 2 files changed, 18 insertions(+), 10 deletions(-)
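
To illustrate the ordering change described in the commit message, here is a minimal user-space sketch. It is not kernel code: `ts_sock`, `fail_at`, `send_old()` and `send_new()` are hypothetical names invented for this model, and a C11 `atomic_uint` stands in for `sk_tskey`. `FAIL_EARLY` models an error hit before the point where the patch now takes the key, and `FAIL_LATE` models an error that reaches the error path after the key was taken.

```c
#include <stdatomic.h>
#include <stdint.h>

/* Hypothetical stand-in for the socket; only the key counter is modeled. */
struct ts_sock {
	atomic_uint tskey;
};

/* Hypothetical failure points, used only by this model. */
enum fail_at { FAIL_NONE, FAIL_EARLY, FAIL_LATE };

/* Old ordering: the key is taken before anything can fail. */
static int send_old(struct ts_sock *sk, enum fail_at fail, uint32_t *key)
{
	/* atomic_fetch_add() returns the old value, like atomic_inc_return() - 1 */
	*key = atomic_fetch_add(&sk->tskey, 1);
	if (fail != FAIL_NONE)
		return -1;	/* the key stays consumed */
	return 0;
}

/* Patched ordering: take the key late, give it back in the error path. */
static int send_new(struct ts_sock *sk, enum fail_at fail, uint32_t *key)
{
	if (fail == FAIL_EARLY)
		return -1;	/* counter not touched yet */

	*key = atomic_fetch_add(&sk->tskey, 1);
	if (fail == FAIL_LATE) {
		atomic_fetch_sub(&sk->tskey, 1);	/* roll back */
		return -1;
	}
	return 0;
}
```

In this model a failed `send_old()` leaves the counter one higher than the number of datagrams actually sent, while a failed `send_new()` leaves it unchanged. The race mentioned in the commit message is visible here as well: if a second sender takes key 1 between the first sender's increment and its rollback, the counter drops back to 1 and the next datagram reuses key 1.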

Comments

Willem de Bruijn Feb. 11, 2024, 5:42 p.m. UTC | #1
Vadim Fedorenko wrote:
> When SOF_TIMESTAMPING_OPT_ID is used to disambiguate timestamped datagrams,
> sk_tskey can become unpredictable if any error happens during sendmsg().
> Move the increment later in the code and decrement sk_tskey in the error
> path. This solution is still racy when multiple threads call sendmsg() on
> the very same socket in parallel, but it makes the error path much more
> predictable.
> 
> Fixes: 09c2d251b707 ("net-timestamp: add key to disambiguate concurrent datagrams")
> Reported-by: Andy Lutomirski <luto@amacapital.net>
> Signed-off-by: Vadim Fedorenko <vadfed@meta.com>
> ---
>  net/ipv4/ip_output.c  | 14 +++++++++-----
>  net/ipv6/ip6_output.c | 14 +++++++++-----
>  2 files changed, 18 insertions(+), 10 deletions(-)
> 
> diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
> index 41537d18eecf..ac4995ed17c7 100644
> --- a/net/ipv4/ip_output.c
> +++ b/net/ipv4/ip_output.c
> @@ -974,7 +974,7 @@ static int __ip_append_data(struct sock *sk,
>  	struct rtable *rt = (struct rtable *)cork->dst;
>  	unsigned int wmem_alloc_delta = 0;
>  	bool paged, extra_uref = false;
> -	u32 tskey = 0;
> +	u32 tsflags, tskey = 0;
>  
>  	skb = skb_peek_tail(queue);
>  
> @@ -982,10 +982,6 @@ static int __ip_append_data(struct sock *sk,
>  	mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize;
>  	paged = !!cork->gso_size;
>  
> -	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
> -	    READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID)
> -		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
> -
>  	hh_len = LL_RESERVED_SPACE(rt->dst.dev);
>  
>  	fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0);
> @@ -1052,6 +1048,11 @@ static int __ip_append_data(struct sock *sk,
>  
>  	cork->length += length;
>  
> +	tsflags = READ_ONCE(sk->sk_tsflags);
> +	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
> +	    tsflags & SOF_TIMESTAMPING_OPT_ID)
> +		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
> +
>  	/* So, what's going on in the loop below?
>  	 *
>  	 * We use calculated fragment length to generate chained skb,
> @@ -1274,6 +1275,9 @@ static int __ip_append_data(struct sock *sk,
>  	cork->length -= length;
>  	IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS);
>  	refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
> +	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
> +	    tsflags & SOF_TIMESTAMPING_OPT_ID)
> +		atomic_dec(&sk->sk_tskey);

Instead of testing the same conditional twice, have a local bool,
e.g., hold_tskey? Akin to extra_uref for MSG_ZEROCOPY.

Vadim Fedorenko Feb. 11, 2024, 11:46 p.m. UTC | #2
On 11/02/2024 12:42, Willem de Bruijn wrote:
> Vadim Fedorenko wrote:
>> When SOF_TIMESTAMPING_OPT_ID is used to disambiguate timestamped datagrams,
>> sk_tskey can become unpredictable if any error happens during sendmsg().
>> Move the increment later in the code and decrement sk_tskey in the error
>> path. This solution is still racy when multiple threads call sendmsg() on
>> the very same socket in parallel, but it makes the error path much more
>> predictable.
>>
>> Fixes: 09c2d251b707 ("net-timestamp: add key to disambiguate concurrent datagrams")
>> Reported-by: Andy Lutomirski <luto@amacapital.net>
>> Signed-off-by: Vadim Fedorenko <vadfed@meta.com>
>> ---
>>   net/ipv4/ip_output.c  | 14 +++++++++-----
>>   net/ipv6/ip6_output.c | 14 +++++++++-----
>>   2 files changed, 18 insertions(+), 10 deletions(-)
>>
>> diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
>> index 41537d18eecf..ac4995ed17c7 100644
>> --- a/net/ipv4/ip_output.c
>> +++ b/net/ipv4/ip_output.c
>> @@ -974,7 +974,7 @@ static int __ip_append_data(struct sock *sk,
>>   	struct rtable *rt = (struct rtable *)cork->dst;
>>   	unsigned int wmem_alloc_delta = 0;
>>   	bool paged, extra_uref = false;
>> -	u32 tskey = 0;
>> +	u32 tsflags, tskey = 0;
>>   
>>   	skb = skb_peek_tail(queue);
>>   
>> @@ -982,10 +982,6 @@ static int __ip_append_data(struct sock *sk,
>>   	mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize;
>>   	paged = !!cork->gso_size;
>>   
>> -	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
>> -	    READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID)
>> -		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
>> -
>>   	hh_len = LL_RESERVED_SPACE(rt->dst.dev);
>>   
>>   	fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0);
>> @@ -1052,6 +1048,11 @@ static int __ip_append_data(struct sock *sk,
>>   
>>   	cork->length += length;
>>   
>> +	tsflags = READ_ONCE(sk->sk_tsflags);
>> +	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
>> +	    tsflags & SOF_TIMESTAMPING_OPT_ID)
>> +		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
>> +
>>   	/* So, what's going on in the loop below?
>>   	 *
>>   	 * We use calculated fragment length to generate chained skb,
>> @@ -1274,6 +1275,9 @@ static int __ip_append_data(struct sock *sk,
>>   	cork->length -= length;
>>   	IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS);
>>   	refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
>> +	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
>> +	    tsflags & SOF_TIMESTAMPING_OPT_ID)
>> +		atomic_dec(&sk->sk_tskey);
> 
> Instead of testing the same conditional twice, have a local bool,
> e.g., hold_tskey? Akin to extra_uref for MSG_ZEROCOPY.
> 

Ok, sure, will post v2 soon
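
For reference, the hold_tskey suggestion above could look roughly like the following user-space sketch. This is an assumption about the shape of the change, not the actual v2: `hold_sock`, `model_append_data()` and the `MODEL_*` flag values are hypothetical stand-ins for `struct sock`, `__ip_append_data()`, `SKBTX_ANY_TSTAMP` and `SOF_TIMESTAMPING_OPT_ID`, and the `fail` parameter only simulates an error inside the send path.

```c
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical flag values for this model only, not the kernel's. */
#define MODEL_TX_ANY_TSTAMP	0x1u
#define MODEL_OPT_ID		0x2u

/* Hypothetical stand-in for the socket fields involved. */
struct hold_sock {
	atomic_uint tskey;
	uint32_t tsflags;
};

static int model_append_data(struct hold_sock *sk, uint32_t tx_flags,
			     bool fail, uint32_t *tskey)
{
	bool hold_tskey;

	/* Evaluate the condition once and remember the result. */
	hold_tskey = (tx_flags & MODEL_TX_ANY_TSTAMP) &&
		     (sk->tsflags & MODEL_OPT_ID);
	if (hold_tskey)
		*tskey = atomic_fetch_add(&sk->tskey, 1);

	if (fail)
		goto error;
	return 0;

error:
	/* Undo the increment only if this call actually took a key. */
	if (hold_tskey)
		atomic_fetch_sub(&sk->tskey, 1);
	return -1;
}
```

With a single bool, the error path cannot disagree with the increment about whether a key was taken, and the flag test is written only once.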

Patch

diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 41537d18eecf..ac4995ed17c7 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -974,7 +974,7 @@  static int __ip_append_data(struct sock *sk,
 	struct rtable *rt = (struct rtable *)cork->dst;
 	unsigned int wmem_alloc_delta = 0;
 	bool paged, extra_uref = false;
-	u32 tskey = 0;
+	u32 tsflags, tskey = 0;
 
 	skb = skb_peek_tail(queue);
 
@@ -982,10 +982,6 @@  static int __ip_append_data(struct sock *sk,
 	mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize;
 	paged = !!cork->gso_size;
 
-	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
-	    READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID)
-		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
-
 	hh_len = LL_RESERVED_SPACE(rt->dst.dev);
 
 	fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0);
@@ -1052,6 +1048,11 @@  static int __ip_append_data(struct sock *sk,
 
 	cork->length += length;
 
+	tsflags = READ_ONCE(sk->sk_tsflags);
+	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
+	    tsflags & SOF_TIMESTAMPING_OPT_ID)
+		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
+
 	/* So, what's going on in the loop below?
 	 *
 	 * We use calculated fragment length to generate chained skb,
@@ -1274,6 +1275,9 @@  static int __ip_append_data(struct sock *sk,
 	cork->length -= length;
 	IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS);
 	refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
+	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
+	    tsflags & SOF_TIMESTAMPING_OPT_ID)
+		atomic_dec(&sk->sk_tskey);
 	return err;
 }
 
diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c
index a722a43dd668..42e423012c18 100644
--- a/net/ipv6/ip6_output.c
+++ b/net/ipv6/ip6_output.c
@@ -1422,7 +1422,7 @@  static int __ip6_append_data(struct sock *sk,
 	int err;
 	int offset = 0;
 	bool zc = false;
-	u32 tskey = 0;
+	u32 tsflags, tskey = 0;
 	struct rt6_info *rt = (struct rt6_info *)cork->dst;
 	struct ipv6_txoptions *opt = v6_cork->opt;
 	int csummode = CHECKSUM_NONE;
@@ -1440,10 +1440,6 @@  static int __ip6_append_data(struct sock *sk,
 	mtu = cork->gso_size ? IP6_MAX_MTU : cork->fragsize;
 	orig_mtu = mtu;
 
-	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
-	    READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID)
-		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
-
 	hh_len = LL_RESERVED_SPACE(rt->dst.dev);
 
 	fragheaderlen = sizeof(struct ipv6hdr) + rt->rt6i_nfheader_len +
@@ -1538,6 +1534,11 @@  static int __ip6_append_data(struct sock *sk,
 			flags &= ~MSG_SPLICE_PAGES;
 	}
 
+	tsflags = READ_ONCE(sk->sk_tsflags);
+	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
+	    tsflags & SOF_TIMESTAMPING_OPT_ID)
+		tskey = atomic_inc_return(&sk->sk_tskey) - 1;
+
 	/*
 	 * Let's try using as much space as possible.
 	 * Use MTU if total length of the message fits into the MTU.
@@ -1794,6 +1795,9 @@  static int __ip6_append_data(struct sock *sk,
 	cork->length -= length;
 	IP6_INC_STATS(sock_net(sk), rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS);
 	refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc);
+	if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
+	    tsflags & SOF_TIMESTAMPING_OPT_ID)
+		atomic_dec(&sk->sk_tskey);
 	return err;
 }