diff mbox series

[net-next,v5,2/4] sock: support put_cmsg to userspace in TX path

Message ID 20240613233133.2463193-3-zijianzhang@bytedance.com (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers
Headers show
Series net: A lightweight zero-copy notification | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for net-next, async
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 6245 this patch: 6245
netdev/build_tools success Errors and warnings before: 0 this patch: 0
netdev/cc_maintainers warning 4 maintainers not CCed: axboe@kernel.dk pabeni@redhat.com kuniyu@amazon.com kuba@kernel.org
netdev/build_clang success Errors and warnings before: 2002 this patch: 2002
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 16144 this patch: 16144
netdev/checkpatch warning CHECK: extern prototypes should be avoided in .h files WARNING: line length of 82 exceeds 80 columns WARNING: line length of 83 exceeds 80 columns WARNING: line length of 89 exceeds 80 columns WARNING: line length of 90 exceeds 80 columns WARNING: line length of 91 exceeds 80 columns WARNING: line length of 97 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 24 this patch: 24
netdev/source_inline success Was 0 now: 0
netdev/contest success net-next-2024-06-16--18-00 (tests: 659)

Commit Message

Zijian Zhang June 13, 2024, 11:31 p.m. UTC
From: Zijian Zhang <zijianzhang@bytedance.com>

Since ____sys_sendmsg creates a kernel copy of msg_control and passes
that to the callees, put_cmsg will write into this kernel buffer. If
people want to piggyback some information, like timestamps, upon the
return of sendmsg, ____sys_sendmsg would have to copy_to_user to the
original buffer, which is not supported. As a result, users typically
have to call recvmsg on the socket's error queue (MSG_ERRQUEUE),
incurring extra system call overhead.

This commit adds support for put_cmsg to userspace in the TX path by
storing the user msg_control address and length in new fields of struct
msghdr, and by adding a new bit flag, use_msg_control_user_tx, to toggle
the behavior of put_cmsg. This makes it possible to piggyback
information in the msg_control of sendmsg.

Signed-off-by: Zijian Zhang <zijianzhang@bytedance.com>
Signed-off-by: Xiaochun Lu <xiaochun.lu@bytedance.com>
---
 include/linux/socket.h |  4 ++++
 net/compat.c           | 33 +++++++++++++++++++++++++--------
 net/core/scm.c         | 42 ++++++++++++++++++++++++++++++++----------
 net/socket.c           |  2 ++
 4 files changed, 63 insertions(+), 18 deletions(-)

Comments

Willem de Bruijn June 15, 2024, 11:42 a.m. UTC | #1
zijianzhang@ wrote:
> From: Zijian Zhang <zijianzhang@bytedance.com>
> 
> Since ____sys_sendmsg creates a kernel copy of msg_control and passes
> that to the callees, put_cmsg will write into this kernel buffer. If
> people want to piggyback some information like timestamps upon returning
> of sendmsg. ____sys_sendmsg will have to copy_to_user to the original buf,
> which is not supported. As a result, users typically have to call recvmsg
> on the ERRMSG_QUEUE of the socket, incurring extra system call overhead.
> 
> This commit supports put_cmsg to userspace in TX path by storing user
> msg_control address in a new field in struct msghdr, and adding a new bit
> flag use_msg_control_user_tx to toggle the behavior of put_cmsg. Thus,
> it's possible to piggyback information in the msg_control of sendmsg.
> 
> Signed-off-by: Zijian Zhang <zijianzhang@bytedance.com>
> Signed-off-by: Xiaochun Lu <xiaochun.lu@bytedance.com>
> ---
>  include/linux/socket.h |  4 ++++
>  net/compat.c           | 33 +++++++++++++++++++++++++--------
>  net/core/scm.c         | 42 ++++++++++++++++++++++++++++++++----------
>  net/socket.c           |  2 ++
>  4 files changed, 63 insertions(+), 18 deletions(-)
> 
> diff --git a/include/linux/socket.h b/include/linux/socket.h
> index 89d16b90370b..8d3db04f4a39 100644
> --- a/include/linux/socket.h
> +++ b/include/linux/socket.h
> @@ -71,9 +71,12 @@ struct msghdr {
>  		void __user	*msg_control_user;
>  	};
>  	bool		msg_control_is_user : 1;
> +	bool		use_msg_control_user_tx : 1;
>  	bool		msg_get_inq : 1;/* return INQ after receive */
>  	unsigned int	msg_flags;	/* flags on received message */
> +	void __user	*msg_control_user_tx;	/* msg_control_user in TX piggyback path */
>  	__kernel_size_t	msg_controllen;	/* ancillary data buffer length */
> +	__kernel_size_t msg_controllen_user_tx; /* msg_controllen in TX piggyback path */
>  	struct kiocb	*msg_iocb;	/* ptr to iocb for async requests */
>  	struct ubuf_info *msg_ubuf;
>  	int (*sg_from_iter)(struct sock *sk, struct sk_buff *skb,
> @@ -391,6 +394,7 @@ struct ucred {
>  
>  extern int move_addr_to_kernel(void __user *uaddr, int ulen, struct sockaddr_storage *kaddr);
>  extern int put_cmsg(struct msghdr*, int level, int type, int len, void *data);

> diff --git a/net/core/scm.c b/net/core/scm.c
> index 4f6a14babe5a..de70ff1981a1 100644
> --- a/net/core/scm.c
> +++ b/net/core/scm.c
> @@ -228,25 +228,29 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p)
>  }
>  EXPORT_SYMBOL(__scm_send);
>  
> -int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
> +static int __put_cmsg(struct msghdr *msg, int level, int type, int len, void *data)
>  {
>  	int cmlen = CMSG_LEN(len);
> +	__kernel_size_t msg_controllen;
>  
> +	msg_controllen = msg->use_msg_control_user_tx ?
> +		msg->msg_controllen_user_tx : msg->msg_controllen;
>  	if (msg->msg_flags & MSG_CMSG_COMPAT)
>  		return put_cmsg_compat(msg, level, type, len, data);
>  
> -	if (!msg->msg_control || msg->msg_controllen < sizeof(struct cmsghdr)) {
> +	if (!msg->msg_control || msg_controllen < sizeof(struct cmsghdr)) {
>  		msg->msg_flags |= MSG_CTRUNC;
>  		return 0; /* XXX: return error? check spec. */
>  	}
> -	if (msg->msg_controllen < cmlen) {
> +	if (msg_controllen < cmlen) {
>  		msg->msg_flags |= MSG_CTRUNC;
> -		cmlen = msg->msg_controllen;
> +		cmlen = msg_controllen;
>  	}
>  
> -	if (msg->msg_control_is_user) {
> -		struct cmsghdr __user *cm = msg->msg_control_user;
> +	if (msg->use_msg_control_user_tx || msg->msg_control_is_user) {
> +		struct cmsghdr __user *cm;
>  
> +		cm = msg->msg_control_is_user ? msg->msg_control_user : msg->msg_control_user_tx;
>  		check_object_size(data, cmlen - sizeof(*cm), true);
>  
>  		if (!user_write_access_begin(cm, cmlen))
> @@ -267,12 +271,17 @@ int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
>  		memcpy(CMSG_DATA(cm), data, cmlen - sizeof(*cm));
>  	}
>  
> -	cmlen = min(CMSG_SPACE(len), msg->msg_controllen);
> -	if (msg->msg_control_is_user)
> +	cmlen = min(CMSG_SPACE(len), msg_controllen);
> +	if (msg->msg_control_is_user) {
>  		msg->msg_control_user += cmlen;
> -	else
> +		msg->msg_controllen -= cmlen;
> +	} else if (msg->use_msg_control_user_tx) {
> +		msg->msg_control_user_tx += cmlen;
> +		msg->msg_controllen_user_tx -= cmlen;
> +	} else {
>  		msg->msg_control += cmlen;
> -	msg->msg_controllen -= cmlen;
> +		msg->msg_controllen -= cmlen;
> +	}
>  	return 0;
>  
>  efault_end:
> @@ -280,8 +289,21 @@ int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
>  efault:
>  	return -EFAULT;
>  }
> +
> +int put_cmsg(struct msghdr *msg, int level, int type, int len, void *data)
> +{
> +	msg->use_msg_control_user_tx = false;
> +	return __put_cmsg(msg, level, type, len, data);
> +}
>  EXPORT_SYMBOL(put_cmsg);
>  
> +int put_cmsg_user_tx(struct msghdr *msg, int level, int type, int len, void *data)
> +{
> +	msg->use_msg_control_user_tx = true;
> +	return __put_cmsg(msg, level, type, len, data);
> +}
> +EXPORT_SYMBOL(put_cmsg_user_tx);
> +
>  void put_cmsg_scm_timestamping64(struct msghdr *msg, struct scm_timestamping_internal *tss_internal)
>  {
>  	struct scm_timestamping64 tss;
> diff --git a/net/socket.c b/net/socket.c
> index e416920e9399..2755bc7bef9c 100644
> --- a/net/socket.c
> +++ b/net/socket.c
> @@ -2561,6 +2561,8 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
>  		err = -EFAULT;
>  		if (copy_from_user(ctl_buf, msg_sys->msg_control_user, ctl_len))
>  			goto out_freectl;
> +		msg_sys->msg_control_user_tx = msg_sys->msg_control_user;
> +		msg_sys->msg_controllen_user_tx = msg_sys->msg_controllen;

No need for this separate user_tx pointer and put_cmsg_user_tx.

___sys_sendmsg copies the user data to a stack allocated kernel
buffer. All subsequent operations are on this buffer. __put_cmsg
already supports writing to this kernel buffer.

All that is needed is to copy_to_user the buffer on return from
__sock_sendmsg. And only if it should be copied, which the bit in
msghdr can signal.

>  		msg_sys->msg_control = ctl_buf;
>  		msg_sys->msg_control_is_user = false;
>  	}
> -- 
> 2.20.1
>
Zijian Zhang June 15, 2024, 7:06 p.m. UTC | #2
On 6/15/24 4:42 AM, Willem de Bruijn wrote:
> zijianzhang@ wrote:
>> From: Zijian Zhang <zijianzhang@bytedance.com>
>>
>> Since ____sys_sendmsg creates a kernel copy of msg_control and passes
>> that to the callees, put_cmsg will write into this kernel buffer. If
>> people want to piggyback some information like timestamps upon returning
>> of sendmsg. ____sys_sendmsg will have to copy_to_user to the original buf,
>> which is not supported. As a result, users typically have to call recvmsg
>> on the ERRMSG_QUEUE of the socket, incurring extra system call overhead.
>>
>> This commit supports put_cmsg to userspace in TX path by storing user
>> msg_control address in a new field in struct msghdr, and adding a new bit
>> flag use_msg_control_user_tx to toggle the behavior of put_cmsg. Thus,
>> it's possible to piggyback information in the msg_control of sendmsg.
>>
>> Signed-off-by: Zijian Zhang <zijianzhang@bytedance.com>
>> Signed-off-by: Xiaochun Lu <xiaochun.lu@bytedance.com>
>> ---
>>   include/linux/socket.h |  4 ++++
>>   net/compat.c           | 33 +++++++++++++++++++++++++--------
>>   net/core/scm.c         | 42 ++++++++++++++++++++++++++++++++----------
>>   net/socket.c           |  2 ++
>>   4 files changed, 63 insertions(+), 18 deletions(-)
>>
>> diff --git a/include/linux/socket.h b/include/linux/socket.h
>> index 89d16b90370b..8d3db04f4a39 100644
>> --- a/include/linux/socket.h
>> +++ b/include/linux/socket.h
>> @@ -71,9 +71,12 @@ struct msghdr {
>>   		void __user	*msg_control_user;
>>   	};
>>   	bool		msg_control_is_user : 1;
>> +	bool		use_msg_control_user_tx : 1;
>>   	bool		msg_get_inq : 1;/* return INQ after receive */
>>   	unsigned int	msg_flags;	/* flags on received message */
>> +	void __user	*msg_control_user_tx;	/* msg_control_user in TX piggyback path */
>>   	__kernel_size_t	msg_controllen;	/* ancillary data buffer length */
>> +	__kernel_size_t msg_controllen_user_tx; /* msg_controllen in TX piggyback path */
>>   	struct kiocb	*msg_iocb;	/* ptr to iocb for async requests */
>>   	struct ubuf_info *msg_ubuf;
>>   	int (*sg_from_iter)(struct sock *sk, struct sk_buff *skb,
>> @@ -391,6 +394,7 @@ struct ucred {
>>   
>>   extern int move_addr_to_kernel(void __user *uaddr, int ulen, struct sockaddr_storage *kaddr);
>>   extern int put_cmsg(struct msghdr*, int level, int type, int len, void *data);
> 
>> diff --git a/net/core/scm.c b/net/core/scm.c
>> index 4f6a14babe5a..de70ff1981a1 100644
>> --- a/net/core/scm.c
>> +++ b/net/core/scm.c
>> @@ -228,25 +228,29 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p)
>>   }
>>   EXPORT_SYMBOL(__scm_send);
>>   
>> -int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
>> +static int __put_cmsg(struct msghdr *msg, int level, int type, int len, void *data)
>>   {
>>   	int cmlen = CMSG_LEN(len);
>> +	__kernel_size_t msg_controllen;
>>   
>> +	msg_controllen = msg->use_msg_control_user_tx ?
>> +		msg->msg_controllen_user_tx : msg->msg_controllen;
>>   	if (msg->msg_flags & MSG_CMSG_COMPAT)
>>   		return put_cmsg_compat(msg, level, type, len, data);
>>   
>> -	if (!msg->msg_control || msg->msg_controllen < sizeof(struct cmsghdr)) {
>> +	if (!msg->msg_control || msg_controllen < sizeof(struct cmsghdr)) {
>>   		msg->msg_flags |= MSG_CTRUNC;
>>   		return 0; /* XXX: return error? check spec. */
>>   	}
>> -	if (msg->msg_controllen < cmlen) {
>> +	if (msg_controllen < cmlen) {
>>   		msg->msg_flags |= MSG_CTRUNC;
>> -		cmlen = msg->msg_controllen;
>> +		cmlen = msg_controllen;
>>   	}
>>   
>> -	if (msg->msg_control_is_user) {
>> -		struct cmsghdr __user *cm = msg->msg_control_user;
>> +	if (msg->use_msg_control_user_tx || msg->msg_control_is_user) {
>> +		struct cmsghdr __user *cm;
>>   
>> +		cm = msg->msg_control_is_user ? msg->msg_control_user : msg->msg_control_user_tx;
>>   		check_object_size(data, cmlen - sizeof(*cm), true);
>>   
>>   		if (!user_write_access_begin(cm, cmlen))
>> @@ -267,12 +271,17 @@ int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
>>   		memcpy(CMSG_DATA(cm), data, cmlen - sizeof(*cm));
>>   	}
>>   
>> -	cmlen = min(CMSG_SPACE(len), msg->msg_controllen);
>> -	if (msg->msg_control_is_user)
>> +	cmlen = min(CMSG_SPACE(len), msg_controllen);
>> +	if (msg->msg_control_is_user) {
>>   		msg->msg_control_user += cmlen;
>> -	else
>> +		msg->msg_controllen -= cmlen;
>> +	} else if (msg->use_msg_control_user_tx) {
>> +		msg->msg_control_user_tx += cmlen;
>> +		msg->msg_controllen_user_tx -= cmlen;
>> +	} else {
>>   		msg->msg_control += cmlen;
>> -	msg->msg_controllen -= cmlen;
>> +		msg->msg_controllen -= cmlen;
>> +	}
>>   	return 0;
>>   
>>   efault_end:
>> @@ -280,8 +289,21 @@ int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
>>   efault:
>>   	return -EFAULT;
>>   }
>> +
>> +int put_cmsg(struct msghdr *msg, int level, int type, int len, void *data)
>> +{
>> +	msg->use_msg_control_user_tx = false;
>> +	return __put_cmsg(msg, level, type, len, data);
>> +}
>>   EXPORT_SYMBOL(put_cmsg);
>>   
>> +int put_cmsg_user_tx(struct msghdr *msg, int level, int type, int len, void *data)
>> +{
>> +	msg->use_msg_control_user_tx = true;
>> +	return __put_cmsg(msg, level, type, len, data);
>> +}
>> +EXPORT_SYMBOL(put_cmsg_user_tx);
>> +
>>   void put_cmsg_scm_timestamping64(struct msghdr *msg, struct scm_timestamping_internal *tss_internal)
>>   {
>>   	struct scm_timestamping64 tss;
>> diff --git a/net/socket.c b/net/socket.c
>> index e416920e9399..2755bc7bef9c 100644
>> --- a/net/socket.c
>> +++ b/net/socket.c
>> @@ -2561,6 +2561,8 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
>>   		err = -EFAULT;
>>   		if (copy_from_user(ctl_buf, msg_sys->msg_control_user, ctl_len))
>>   			goto out_freectl;
>> +		msg_sys->msg_control_user_tx = msg_sys->msg_control_user;
>> +		msg_sys->msg_controllen_user_tx = msg_sys->msg_controllen;
> 
> No need for this separate user_tx pointer and put_cmsg_user_tx.
> 
> ___sys_sendmsg copies the user data to a stack allocated kernel
> buffer. All subsequent operations are on this buffer. __put_cmsg
> already supports writing to this kernel buffer.
> 
> All that is needed is to copy_to_user the buffer on return from
> __sock_sendmsg. And only if it should be copied, which the bit in
> msghdr can signal.
>
copy_to_user upon returning from __sock_sendmsg is clean, but we may
need to take compat into account.

put_cmsg already handles compat cleanly, so I am trying to reuse it.

Since msg_control_user is overwritten in ____sys_sendmsg with a pointer
to a kernel stack buffer, I save the user pointer in msg_control_user_tx
for later use by put_cmsg_user_tx.

Alternatively, upon return from ____sys_sendmsg, we could set
msg_control_user back to the user address. Then, in a for_each_cmsghdr
loop, we could call put_cmsg for each cmsg whose cmsg_type is SCM_ZC_...?

>>   		msg_sys->msg_control = ctl_buf;
>>   		msg_sys->msg_control_is_user = false;
>>   	}
>> -- 
>> 2.20.1
>>
> 
>
diff mbox series

Patch

diff --git a/include/linux/socket.h b/include/linux/socket.h
index 89d16b90370b..8d3db04f4a39 100644
--- a/include/linux/socket.h
+++ b/include/linux/socket.h
@@ -71,9 +71,12 @@  struct msghdr {
 		void __user	*msg_control_user;
 	};
 	bool		msg_control_is_user : 1;
+	bool		use_msg_control_user_tx : 1;
 	bool		msg_get_inq : 1;/* return INQ after receive */
 	unsigned int	msg_flags;	/* flags on received message */
+	void __user	*msg_control_user_tx;	/* TX cmsg user buffer */
 	__kernel_size_t	msg_controllen;	/* ancillary data buffer length */
+	__kernel_size_t	msg_controllen_user_tx;	/* TX cmsg user buffer len */
 	struct kiocb	*msg_iocb;	/* ptr to iocb for async requests */
 	struct ubuf_info *msg_ubuf;
 	int (*sg_from_iter)(struct sock *sk, struct sk_buff *skb,
@@ -391,6 +394,7 @@  struct ucred {
 
 extern int move_addr_to_kernel(void __user *uaddr, int ulen, struct sockaddr_storage *kaddr);
 extern int put_cmsg(struct msghdr*, int level, int type, int len, void *data);
+int put_cmsg_user_tx(struct msghdr *msg, int level, int type, int len, void *data);
 
 struct timespec64;
 struct __kernel_timespec;
diff --git a/net/compat.c b/net/compat.c
index 485db8ee9b28..ae9d78b1c18b 100644
--- a/net/compat.c
+++ b/net/compat.c
@@ -211,6 +211,8 @@  int cmsghdr_from_user_compat_to_kern(struct msghdr *kmsg, struct sock *sk,
 		goto Einval;
 
 	/* Ok, looks like we made it.  Hook it up and return success. */
+	kmsg->msg_control_user_tx = kmsg->msg_control_user;
+	kmsg->msg_controllen_user_tx = kmsg->msg_controllen;
 	kmsg->msg_control_is_user = false;
 	kmsg->msg_control = kcmsg_base;
 	kmsg->msg_controllen = kcmlen;
@@ -226,13 +228,22 @@  int cmsghdr_from_user_compat_to_kern(struct msghdr *kmsg, struct sock *sk,
 
 int put_cmsg_compat(struct msghdr *kmsg, int level, int type, int len, void *data)
 {
-	struct compat_cmsghdr __user *cm = (struct compat_cmsghdr __user *) kmsg->msg_control_user;
+	struct compat_cmsghdr __user *cm;
 	struct compat_cmsghdr cmhdr;
 	struct old_timeval32 ctv;
 	struct old_timespec32 cts[3];
+	compat_size_t msg_controllen;
 	int cmlen;
 
-	if (cm == NULL || kmsg->msg_controllen < sizeof(*cm)) {
+	if (kmsg->use_msg_control_user_tx && !kmsg->msg_control_is_user) {
+		cm = (struct compat_cmsghdr __user *)kmsg->msg_control_user_tx;
+		msg_controllen = kmsg->msg_controllen_user_tx;
+	} else {
+		cm = (struct compat_cmsghdr __user *)kmsg->msg_control_user;
+		msg_controllen = kmsg->msg_controllen;
+	}
+
+	if (!cm || msg_controllen < sizeof(*cm)) {
 		kmsg->msg_flags |= MSG_CTRUNC;
 		return 0; /* XXX: return error? check spec. */
 	}
@@ -260,9 +271,9 @@  int put_cmsg_compat(struct msghdr *kmsg, int level, int type, int len, void *dat
 	}
 
 	cmlen = CMSG_COMPAT_LEN(len);
-	if (kmsg->msg_controllen < cmlen) {
+	if (msg_controllen < cmlen) {
 		kmsg->msg_flags |= MSG_CTRUNC;
-		cmlen = kmsg->msg_controllen;
+		cmlen = msg_controllen;
 	}
 	cmhdr.cmsg_level = level;
 	cmhdr.cmsg_type = type;
@@ -273,10 +284,16 @@  int put_cmsg_compat(struct msghdr *kmsg, int level, int type, int len, void *dat
 	if (copy_to_user(CMSG_COMPAT_DATA(cm), data, cmlen - sizeof(struct compat_cmsghdr)))
 		return -EFAULT;
 	cmlen = CMSG_COMPAT_SPACE(len);
-	if (kmsg->msg_controllen < cmlen)
-		cmlen = kmsg->msg_controllen;
-	kmsg->msg_control_user += cmlen;
-	kmsg->msg_controllen -= cmlen;
+	if (msg_controllen < cmlen)
+		cmlen = msg_controllen;
+
+	if (kmsg->use_msg_control_user_tx && !kmsg->msg_control_is_user) {
+		kmsg->msg_control_user_tx += cmlen;
+		kmsg->msg_controllen_user_tx -= cmlen;
+	} else {
+		kmsg->msg_control_user += cmlen;
+		kmsg->msg_controllen -= cmlen;
+	}
 	return 0;
 }
 
diff --git a/net/core/scm.c b/net/core/scm.c
index 4f6a14babe5a..de70ff1981a1 100644
--- a/net/core/scm.c
+++ b/net/core/scm.c
@@ -228,25 +228,29 @@  int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p)
 }
 EXPORT_SYMBOL(__scm_send);
 
-int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
+static int __put_cmsg(struct msghdr *msg, int level, int type, int len, void *data)
 {
 	int cmlen = CMSG_LEN(len);
+	__kernel_size_t msg_controllen;
 
+	msg_controllen = (msg->use_msg_control_user_tx && !msg->msg_control_is_user) ?
+		msg->msg_controllen_user_tx : msg->msg_controllen;
 	if (msg->msg_flags & MSG_CMSG_COMPAT)
 		return put_cmsg_compat(msg, level, type, len, data);
 
-	if (!msg->msg_control || msg->msg_controllen < sizeof(struct cmsghdr)) {
+	if (!msg->msg_control || msg_controllen < sizeof(struct cmsghdr)) {
 		msg->msg_flags |= MSG_CTRUNC;
 		return 0; /* XXX: return error? check spec. */
 	}
-	if (msg->msg_controllen < cmlen) {
+	if (msg_controllen < cmlen) {
 		msg->msg_flags |= MSG_CTRUNC;
-		cmlen = msg->msg_controllen;
+		cmlen = msg_controllen;
 	}
 
-	if (msg->msg_control_is_user) {
-		struct cmsghdr __user *cm = msg->msg_control_user;
+	if (msg->use_msg_control_user_tx || msg->msg_control_is_user) {
+		struct cmsghdr __user *cm;
 
+		cm = msg->msg_control_is_user ? msg->msg_control_user : msg->msg_control_user_tx;
 		check_object_size(data, cmlen - sizeof(*cm), true);
 
 		if (!user_write_access_begin(cm, cmlen))
@@ -267,12 +271,17 @@  int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
 		memcpy(CMSG_DATA(cm), data, cmlen - sizeof(*cm));
 	}
 
-	cmlen = min(CMSG_SPACE(len), msg->msg_controllen);
-	if (msg->msg_control_is_user)
+	cmlen = min(CMSG_SPACE(len), msg_controllen);
+	if (msg->msg_control_is_user) {
 		msg->msg_control_user += cmlen;
-	else
+		msg->msg_controllen -= cmlen;
+	} else if (msg->use_msg_control_user_tx) {
+		msg->msg_control_user_tx += cmlen;
+		msg->msg_controllen_user_tx -= cmlen;
+	} else {
 		msg->msg_control += cmlen;
-	msg->msg_controllen -= cmlen;
+		msg->msg_controllen -= cmlen;
+	}
 	return 0;
 
 efault_end:
@@ -280,8 +289,21 @@  int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data)
 efault:
 	return -EFAULT;
 }
+
+int put_cmsg(struct msghdr *msg, int level, int type, int len, void *data)
+{
+	msg->use_msg_control_user_tx = false;
+	return __put_cmsg(msg, level, type, len, data);
+}
 EXPORT_SYMBOL(put_cmsg);
 
+int put_cmsg_user_tx(struct msghdr *msg, int level, int type, int len, void *data)
+{
+	msg->use_msg_control_user_tx = true;
+	return __put_cmsg(msg, level, type, len, data);
+}
+EXPORT_SYMBOL(put_cmsg_user_tx);
+
 void put_cmsg_scm_timestamping64(struct msghdr *msg, struct scm_timestamping_internal *tss_internal)
 {
 	struct scm_timestamping64 tss;
diff --git a/net/socket.c b/net/socket.c
index e416920e9399..2755bc7bef9c 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -2561,6 +2561,8 @@  static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
 		err = -EFAULT;
 		if (copy_from_user(ctl_buf, msg_sys->msg_control_user, ctl_len))
 			goto out_freectl;
+		msg_sys->msg_control_user_tx = msg_sys->msg_control_user; /* save user ptr before msg_control is repointed */
+		msg_sys->msg_controllen_user_tx = msg_sys->msg_controllen;
 		msg_sys->msg_control = ctl_buf;
 		msg_sys->msg_control_is_user = false;
 	}