diff mbox series

[v3,01/28] net/socket.c: switch to CLASS(fd)

Message ID 20241102050827.2451599-1-viro@zeniv.linux.org.uk (mailing list archive)
State New
Headers show
Series [v3,01/28] net/socket.c: switch to CLASS(fd) | expand

Commit Message

Al Viro Nov. 2, 2024, 5:07 a.m. UTC
The important part in sockfd_lookup_light() is avoiding needless
file refcount operations, not the marginal reduction of the register
pressure from not keeping a struct file pointer in the caller.

	Switch to use fdget()/fdpu(); with sane use of CLASS(fd) we can
get a better code generation...

	Would be nice if somebody tested it on networking test suites
(including benchmarks)...

	sockfd_lookup_light() does fdget(), uses sock_from_file() to
get the associated socket and returns the struct socket reference to
the caller, along with "do we need to fput()" flag.  No matching fdput(),
the caller does its equivalent manually, using the fact that sock->file
points to the struct file the socket has come from.

	Get rid of that - have the callers do fdget()/fdput() and
use sock_from_file() directly.  That kills sockfd_lookup_light()
and fput_light() (no users left).

	What's more, we can get rid of explicit fdget()/fdput() by
switching to CLASS(fd, ...) - code generation does not suffer, since
now fdput() inserted on "descriptor is not opened" failure exit
is recognized to be a no-op by compiler.

Reviewed-by: Christian Brauner <brauner@kernel.org>
Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>
---
 include/linux/file.h |   6 -
 net/socket.c         | 303 +++++++++++++++++++------------------------
 2 files changed, 137 insertions(+), 172 deletions(-)

Comments

Simon Horman Nov. 2, 2024, 12:21 p.m. UTC | #1
On Sat, Nov 02, 2024 at 05:07:59AM +0000, Al Viro wrote:
> 	The important part in sockfd_lookup_light() is avoiding needless
> file refcount operations, not the marginal reduction of the register
> pressure from not keeping a struct file pointer in the caller.
> 
> 	Switch to use fdget()/fdpu(); with sane use of CLASS(fd) we can
> get a better code generation...
> 
> 	Would be nice if somebody tested it on networking test suites
> (including benchmarks)...
> 
> 	sockfd_lookup_light() does fdget(), uses sock_from_file() to
> get the associated socket and returns the struct socket reference to
> the caller, along with "do we need to fput()" flag.  No matching fdput(),
> the caller does its equivalent manually, using the fact that sock->file
> points to the struct file the socket has come from.
> 
> 	Get rid of that - have the callers do fdget()/fdput() and
> use sock_from_file() directly.  That kills sockfd_lookup_light()
> and fput_light() (no users left).
> 
> 	What's more, we can get rid of explicit fdget()/fdput() by
> switching to CLASS(fd, ...) - code generation does not suffer, since
> now fdput() inserted on "descriptor is not opened" failure exit
> is recognized to be a no-op by compiler.
> 
> Reviewed-by: Christian Brauner <brauner@kernel.org>
> Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>

...

> diff --git a/net/socket.c b/net/socket.c

...

> @@ -2926,16 +2900,18 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
>  
>  	datagrams = 0;
>  
> -	sock = sockfd_lookup_light(fd, &err, &fput_needed);
> -	if (!sock)
> -		return err;
> +	CLASS(fd, f)(fd);
> +
> +	if (fd_empty(f))
> +		return -EBADF;
> +	sock = sock_from_file(fd_file(f));
> +	if (unlikely(!sock))
> +		return -ENOTSOCK;

Hi Al,

There is an unconditional check on err down on line 2977.
However, with the above change err is now only conditionally
set before we reach that line. Are you sure that it will always
be initialised by the time line 2977 is reached?

...
Al Viro Nov. 3, 2024, 6:31 a.m. UTC | #2
On Sat, Nov 02, 2024 at 12:21:32PM +0000, Simon Horman wrote:

> > @@ -2926,16 +2900,18 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
> >  
> >  	datagrams = 0;
> >  
> > -	sock = sockfd_lookup_light(fd, &err, &fput_needed);
> > -	if (!sock)
> > -		return err;
> > +	CLASS(fd, f)(fd);
> > +
> > +	if (fd_empty(f))
> > +		return -EBADF;
> > +	sock = sock_from_file(fd_file(f));
> > +	if (unlikely(!sock))
> > +		return -ENOTSOCK;
> 
> Hi Al,
> 
> There is an unconditional check on err down on line 2977.
> However, with the above change err is now only conditionally
> set before we reach that line. Are you sure that it will always
> be initialised by the time line 2977 is reached?

Nice catch, thank you.  It is possible, if you call recvmmsg(2) with
zero vlen and MSG_ERRQUEUE in flags.  Which is not going to be done in
any well-behaving code, making it really nasty - nothing like a kernel
bug that shows up only when trying to narrow down a userland bug upstream
of the syscall in question ;-/

AFAICS, that's the only bug of that sort in this commit - all other
places that used to rely upon successful sockfd_lookup_light() zeroing
err have an unconditional assignment to err shortly downstream of that.

Fix folded into commit in question, branch force-pushed; incremental follows
(I would rather not spam the lists with repost of the entire patchset for
the sake of that):

diff --git a/net/socket.c b/net/socket.c
index fb3806a11f94..c3ac02d060c0 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -2885,7 +2885,7 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
 			  unsigned int vlen, unsigned int flags,
 			  struct timespec64 *timeout)
 {
-	int err, datagrams;
+	int err = 0, datagrams;
 	struct socket *sock;
 	struct mmsghdr __user *entry;
 	struct compat_mmsghdr __user *compat_entry;
Simon Horman Nov. 6, 2024, 10:03 a.m. UTC | #3
On Sun, Nov 03, 2024 at 06:31:13AM +0000, Al Viro wrote:
> On Sat, Nov 02, 2024 at 12:21:32PM +0000, Simon Horman wrote:
> 
> > > @@ -2926,16 +2900,18 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
> > >  
> > >  	datagrams = 0;
> > >  
> > > -	sock = sockfd_lookup_light(fd, &err, &fput_needed);
> > > -	if (!sock)
> > > -		return err;
> > > +	CLASS(fd, f)(fd);
> > > +
> > > +	if (fd_empty(f))
> > > +		return -EBADF;
> > > +	sock = sock_from_file(fd_file(f));
> > > +	if (unlikely(!sock))
> > > +		return -ENOTSOCK;
> > 
> > Hi Al,
> > 
> > There is an unconditional check on err down on line 2977.
> > However, with the above change err is now only conditionally
> > set before we reach that line. Are you sure that it will always
> > be initialised by the time line 2977 is reached?
> 
> Nice catch, thank you.  It is possible, if you call recvmmsg(2) with
> zero vlen and MSG_ERRQUEUE in flags.  Which is not going to be done in
> any well-behaving code, making it really nasty - nothing like a kernel
> bug that shows up only when trying to narrow down a userland bug upstream
> of the syscall in question ;-/

Ouch.

> AFAICS, that's the only bug of that sort in this commit - all other
> places that used to rely upon successful sockfd_lookup_light() zeroing
> err have an unconditional assignment to err shortly downstream of that.

Yes, I was unable to find any others either.

> 
> Fix folded into commit in question, branch force-pushed; incremental follows
> (I would rather not spam the lists with repost of the entire patchset for
> the sake of that):

Thanks, will run some checks for good measure.
But the fix below looks good to me.

> diff --git a/net/socket.c b/net/socket.c
> index fb3806a11f94..c3ac02d060c0 100644
> --- a/net/socket.c
> +++ b/net/socket.c
> @@ -2885,7 +2885,7 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
>  			  unsigned int vlen, unsigned int flags,
>  			  struct timespec64 *timeout)
>  {
> -	int err, datagrams;
> +	int err = 0, datagrams;
>  	struct socket *sock;
>  	struct mmsghdr __user *entry;
>  	struct compat_mmsghdr __user *compat_entry;
>
diff mbox series

Patch

diff --git a/include/linux/file.h b/include/linux/file.h
index f98de143245a..b49a92295b3f 100644
--- a/include/linux/file.h
+++ b/include/linux/file.h
@@ -30,12 +30,6 @@  extern struct file *alloc_file_pseudo_noaccount(struct inode *, struct vfsmount
 extern struct file *alloc_file_clone(struct file *, int flags,
 	const struct file_operations *);
 
-static inline void fput_light(struct file *file, int fput_needed)
-{
-	if (fput_needed)
-		fput(file);
-}
-
 /* either a reference to struct file + flags
  * (cloned vs. borrowed, pos locked), with
  * flags stored in lower bits of value,
diff --git a/net/socket.c b/net/socket.c
index 601ad74930ef..fb3806a11f94 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -509,7 +509,7 @@  static int sock_map_fd(struct socket *sock, int flags)
 
 struct socket *sock_from_file(struct file *file)
 {
-	if (file->f_op == &socket_file_ops)
+	if (likely(file->f_op == &socket_file_ops))
 		return file->private_data;	/* set in sock_alloc_file */
 
 	return NULL;
@@ -549,24 +549,6 @@  struct socket *sockfd_lookup(int fd, int *err)
 }
 EXPORT_SYMBOL(sockfd_lookup);
 
-static struct socket *sockfd_lookup_light(int fd, int *err, int *fput_needed)
-{
-	struct fd f = fdget(fd);
-	struct socket *sock;
-
-	*err = -EBADF;
-	if (fd_file(f)) {
-		sock = sock_from_file(fd_file(f));
-		if (likely(sock)) {
-			*fput_needed = f.word & FDPUT_FPUT;
-			return sock;
-		}
-		*err = -ENOTSOCK;
-		fdput(f);
-	}
-	return NULL;
-}
-
 static ssize_t sockfs_listxattr(struct dentry *dentry, char *buffer,
 				size_t size)
 {
@@ -1853,16 +1835,20 @@  int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen)
 {
 	struct socket *sock;
 	struct sockaddr_storage address;
-	int err, fput_needed;
-
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (sock) {
-		err = move_addr_to_kernel(umyaddr, addrlen, &address);
-		if (!err)
-			err = __sys_bind_socket(sock, &address, addrlen);
-		fput_light(sock->file, fput_needed);
-	}
-	return err;
+	CLASS(fd, f)(fd);
+	int err;
+
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
+
+	err = move_addr_to_kernel(umyaddr, addrlen, &address);
+	if (unlikely(err))
+		return err;
+
+	return __sys_bind_socket(sock, &address, addrlen);
 }
 
 SYSCALL_DEFINE3(bind, int, fd, struct sockaddr __user *, umyaddr, int, addrlen)
@@ -1891,15 +1877,16 @@  int __sys_listen_socket(struct socket *sock, int backlog)
 
 int __sys_listen(int fd, int backlog)
 {
+	CLASS(fd, f)(fd);
 	struct socket *sock;
-	int err, fput_needed;
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (sock) {
-		err = __sys_listen_socket(sock, backlog);
-		fput_light(sock->file, fput_needed);
-	}
-	return err;
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
+
+	return __sys_listen_socket(sock, backlog);
 }
 
 SYSCALL_DEFINE2(listen, int, fd, int, backlog)
@@ -2009,17 +1996,12 @@  static int __sys_accept4_file(struct file *file, struct sockaddr __user *upeer_s
 int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
 		  int __user *upeer_addrlen, int flags)
 {
-	int ret = -EBADF;
-	struct fd f;
+	CLASS(fd, f)(fd);
 
-	f = fdget(fd);
-	if (fd_file(f)) {
-		ret = __sys_accept4_file(fd_file(f), upeer_sockaddr,
+	if (fd_empty(f))
+		return -EBADF;
+	return __sys_accept4_file(fd_file(f), upeer_sockaddr,
 					 upeer_addrlen, flags);
-		fdput(f);
-	}
-
-	return ret;
 }
 
 SYSCALL_DEFINE4(accept4, int, fd, struct sockaddr __user *, upeer_sockaddr,
@@ -2071,20 +2053,18 @@  int __sys_connect_file(struct file *file, struct sockaddr_storage *address,
 
 int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)
 {
-	int ret = -EBADF;
-	struct fd f;
+	struct sockaddr_storage address;
+	CLASS(fd, f)(fd);
+	int ret;
 
-	f = fdget(fd);
-	if (fd_file(f)) {
-		struct sockaddr_storage address;
+	if (fd_empty(f))
+		return -EBADF;
 
-		ret = move_addr_to_kernel(uservaddr, addrlen, &address);
-		if (!ret)
-			ret = __sys_connect_file(fd_file(f), &address, addrlen, 0);
-		fdput(f);
-	}
+	ret = move_addr_to_kernel(uservaddr, addrlen, &address);
+	if (ret)
+		return ret;
 
-	return ret;
+	return __sys_connect_file(fd_file(f), &address, addrlen, 0);
 }
 
 SYSCALL_DEFINE3(connect, int, fd, struct sockaddr __user *, uservaddr,
@@ -2103,26 +2083,25 @@  int __sys_getsockname(int fd, struct sockaddr __user *usockaddr,
 {
 	struct socket *sock;
 	struct sockaddr_storage address;
-	int err, fput_needed;
+	CLASS(fd, f)(fd);
+	int err;
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		goto out;
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
 	err = security_socket_getsockname(sock);
 	if (err)
-		goto out_put;
+		return err;
 
 	err = READ_ONCE(sock->ops)->getname(sock, (struct sockaddr *)&address, 0);
 	if (err < 0)
-		goto out_put;
-	/* "err" is actually length in this case */
-	err = move_addr_to_user(&address, err, usockaddr, usockaddr_len);
+		return err;
 
-out_put:
-	fput_light(sock->file, fput_needed);
-out:
-	return err;
+	/* "err" is actually length in this case */
+	return move_addr_to_user(&address, err, usockaddr, usockaddr_len);
 }
 
 SYSCALL_DEFINE3(getsockname, int, fd, struct sockaddr __user *, usockaddr,
@@ -2141,26 +2120,25 @@  int __sys_getpeername(int fd, struct sockaddr __user *usockaddr,
 {
 	struct socket *sock;
 	struct sockaddr_storage address;
-	int err, fput_needed;
+	CLASS(fd, f)(fd);
+	int err;
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (sock != NULL) {
-		const struct proto_ops *ops = READ_ONCE(sock->ops);
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
-		err = security_socket_getpeername(sock);
-		if (err) {
-			fput_light(sock->file, fput_needed);
-			return err;
-		}
+	err = security_socket_getpeername(sock);
+	if (err)
+		return err;
 
-		err = ops->getname(sock, (struct sockaddr *)&address, 1);
-		if (err >= 0)
-			/* "err" is actually length in this case */
-			err = move_addr_to_user(&address, err, usockaddr,
-						usockaddr_len);
-		fput_light(sock->file, fput_needed);
-	}
-	return err;
+	err = READ_ONCE(sock->ops)->getname(sock, (struct sockaddr *)&address, 1);
+	if (err < 0)
+		return err;
+
+	/* "err" is actually length in this case */
+	return move_addr_to_user(&address, err, usockaddr, usockaddr_len);
 }
 
 SYSCALL_DEFINE3(getpeername, int, fd, struct sockaddr __user *, usockaddr,
@@ -2181,14 +2159,17 @@  int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags,
 	struct sockaddr_storage address;
 	int err;
 	struct msghdr msg;
-	int fput_needed;
 
 	err = import_ubuf(ITER_SOURCE, buff, len, &msg.msg_iter);
 	if (unlikely(err))
 		return err;
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		goto out;
+
+	CLASS(fd, f)(fd);
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
 	msg.msg_name = NULL;
 	msg.msg_control = NULL;
@@ -2198,7 +2179,7 @@  int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags,
 	if (addr) {
 		err = move_addr_to_kernel(addr, addr_len, &address);
 		if (err < 0)
-			goto out_put;
+			return err;
 		msg.msg_name = (struct sockaddr *)&address;
 		msg.msg_namelen = addr_len;
 	}
@@ -2206,12 +2187,7 @@  int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags,
 	if (sock->file->f_flags & O_NONBLOCK)
 		flags |= MSG_DONTWAIT;
 	msg.msg_flags = flags;
-	err = __sock_sendmsg(sock, &msg);
-
-out_put:
-	fput_light(sock->file, fput_needed);
-out:
-	return err;
+	return __sock_sendmsg(sock, &msg);
 }
 
 SYSCALL_DEFINE6(sendto, int, fd, void __user *, buff, size_t, len,
@@ -2246,14 +2222,18 @@  int __sys_recvfrom(int fd, void __user *ubuf, size_t size, unsigned int flags,
 	};
 	struct socket *sock;
 	int err, err2;
-	int fput_needed;
 
 	err = import_ubuf(ITER_DEST, ubuf, size, &msg.msg_iter);
 	if (unlikely(err))
 		return err;
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		goto out;
+
+	CLASS(fd, f)(fd);
+
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
 	if (sock->file->f_flags & O_NONBLOCK)
 		flags |= MSG_DONTWAIT;
@@ -2265,9 +2245,6 @@  int __sys_recvfrom(int fd, void __user *ubuf, size_t size, unsigned int flags,
 		if (err2 < 0)
 			err = err2;
 	}
-
-	fput_light(sock->file, fput_needed);
-out:
 	return err;
 }
 
@@ -2342,17 +2319,16 @@  int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval,
 {
 	sockptr_t optval = USER_SOCKPTR(user_optval);
 	bool compat = in_compat_syscall();
-	int err, fput_needed;
 	struct socket *sock;
+	CLASS(fd, f)(fd);
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		return err;
-
-	err = do_sock_setsockopt(sock, compat, level, optname, optval, optlen);
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
-	fput_light(sock->file, fput_needed);
-	return err;
+	return do_sock_setsockopt(sock, compat, level, optname, optval, optlen);
 }
 
 SYSCALL_DEFINE5(setsockopt, int, fd, int, level, int, optname,
@@ -2408,20 +2384,17 @@  EXPORT_SYMBOL(do_sock_getsockopt);
 int __sys_getsockopt(int fd, int level, int optname, char __user *optval,
 		int __user *optlen)
 {
-	int err, fput_needed;
 	struct socket *sock;
-	bool compat;
+	CLASS(fd, f)(fd);
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		return err;
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
-	compat = in_compat_syscall();
-	err = do_sock_getsockopt(sock, compat, level, optname,
+	return do_sock_getsockopt(sock, in_compat_syscall(), level, optname,
 				 USER_SOCKPTR(optval), USER_SOCKPTR(optlen));
-
-	fput_light(sock->file, fput_needed);
-	return err;
 }
 
 SYSCALL_DEFINE5(getsockopt, int, fd, int, level, int, optname,
@@ -2447,15 +2420,16 @@  int __sys_shutdown_sock(struct socket *sock, int how)
 
 int __sys_shutdown(int fd, int how)
 {
-	int err, fput_needed;
 	struct socket *sock;
+	CLASS(fd, f)(fd);
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (sock != NULL) {
-		err = __sys_shutdown_sock(sock, how);
-		fput_light(sock->file, fput_needed);
-	}
-	return err;
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
+
+	return __sys_shutdown_sock(sock, how);
 }
 
 SYSCALL_DEFINE2(shutdown, int, fd, int, how)
@@ -2671,22 +2645,21 @@  long __sys_sendmsg_sock(struct socket *sock, struct msghdr *msg,
 long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,
 		   bool forbid_cmsg_compat)
 {
-	int fput_needed, err;
 	struct msghdr msg_sys;
 	struct socket *sock;
 
 	if (forbid_cmsg_compat && (flags & MSG_CMSG_COMPAT))
 		return -EINVAL;
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		goto out;
+	CLASS(fd, f)(fd);
 
-	err = ___sys_sendmsg(sock, msg, &msg_sys, flags, NULL, 0);
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
-	fput_light(sock->file, fput_needed);
-out:
-	return err;
+	return ___sys_sendmsg(sock, msg, &msg_sys, flags, NULL, 0);
 }
 
 SYSCALL_DEFINE3(sendmsg, int, fd, struct user_msghdr __user *, msg, unsigned int, flags)
@@ -2701,7 +2674,7 @@  SYSCALL_DEFINE3(sendmsg, int, fd, struct user_msghdr __user *, msg, unsigned int
 int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
 		   unsigned int flags, bool forbid_cmsg_compat)
 {
-	int fput_needed, err, datagrams;
+	int err, datagrams;
 	struct socket *sock;
 	struct mmsghdr __user *entry;
 	struct compat_mmsghdr __user *compat_entry;
@@ -2717,9 +2690,13 @@  int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
 
 	datagrams = 0;
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		return err;
+	CLASS(fd, f)(fd);
+
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
 	used_address.name_len = UINT_MAX;
 	entry = mmsg;
@@ -2756,8 +2733,6 @@  int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
 		cond_resched();
 	}
 
-	fput_light(sock->file, fput_needed);
-
 	/* We only return an error if no datagrams were able to be sent */
 	if (datagrams != 0)
 		return datagrams;
@@ -2879,22 +2854,21 @@  long __sys_recvmsg_sock(struct socket *sock, struct msghdr *msg,
 long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,
 		   bool forbid_cmsg_compat)
 {
-	int fput_needed, err;
 	struct msghdr msg_sys;
 	struct socket *sock;
 
 	if (forbid_cmsg_compat && (flags & MSG_CMSG_COMPAT))
 		return -EINVAL;
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		goto out;
+	CLASS(fd, f)(fd);
 
-	err = ___sys_recvmsg(sock, msg, &msg_sys, flags, 0);
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
-	fput_light(sock->file, fput_needed);
-out:
-	return err;
+	return ___sys_recvmsg(sock, msg, &msg_sys, flags, 0);
 }
 
 SYSCALL_DEFINE3(recvmsg, int, fd, struct user_msghdr __user *, msg,
@@ -2911,7 +2885,7 @@  static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
 			  unsigned int vlen, unsigned int flags,
 			  struct timespec64 *timeout)
 {
-	int fput_needed, err, datagrams;
+	int err, datagrams;
 	struct socket *sock;
 	struct mmsghdr __user *entry;
 	struct compat_mmsghdr __user *compat_entry;
@@ -2926,16 +2900,18 @@  static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
 
 	datagrams = 0;
 
-	sock = sockfd_lookup_light(fd, &err, &fput_needed);
-	if (!sock)
-		return err;
+	CLASS(fd, f)(fd);
+
+	if (fd_empty(f))
+		return -EBADF;
+	sock = sock_from_file(fd_file(f));
+	if (unlikely(!sock))
+		return -ENOTSOCK;
 
 	if (likely(!(flags & MSG_ERRQUEUE))) {
 		err = sock_error(sock->sk);
-		if (err) {
-			datagrams = err;
-			goto out_put;
-		}
+		if (err)
+			return err;
 	}
 
 	entry = mmsg;
@@ -2992,12 +2968,10 @@  static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
 	}
 
 	if (err == 0)
-		goto out_put;
+		return datagrams;
 
-	if (datagrams == 0) {
-		datagrams = err;
-		goto out_put;
-	}
+	if (datagrams == 0)
+		return err;
 
 	/*
 	 * We may return less entries than requested (vlen) if the
@@ -3012,9 +2986,6 @@  static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
 		 */
 		WRITE_ONCE(sock->sk->sk_err, -err);
 	}
-out_put:
-	fput_light(sock->file, fput_needed);
-
 	return datagrams;
 }