diff mbox series

[2/3] userfaultfd: convert to ->read_iter()

Message ID 20240402202524.1514963-3-axboe@kernel.dk (mailing list archive)
State New, archived
Headers show
Series Convert fs drivers to ->read_iter() | expand

Commit Message

Jens Axboe April 2, 2024, 8:18 p.m. UTC
Rather than use the older style ->read() hook, use ->read_iter() so that
userfaultfd can support both O_NONBLOCK and IOCB_NOWAIT for non-blocking
read attempts.

Split the fd setup into two parts, so that userfaultfd can mark the file
mode with FMODE_NOWAIT before installing it into the process table. With
that, we can also defer grabbing the mm until we know the rest will
succeed, as the fd isn't visible before then.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
---
 fs/userfaultfd.c | 42 ++++++++++++++++++++++++++----------------
 1 file changed, 26 insertions(+), 16 deletions(-)

Comments

Christian Brauner April 3, 2024, 10:09 a.m. UTC | #1
On Tue, Apr 02, 2024 at 02:18:22PM -0600, Jens Axboe wrote:
> Rather than use the older style ->read() hook, use ->read_iter() so that
> userfaultfd can support both O_NONBLOCK and IOCB_NOWAIT for non-blocking
> read attempts.
> 
> Split the fd setup into two parts, so that userfaultfd can mark the file
> mode with FMODE_NOWAIT before installing it into the process table. With
> that, we can also defer grabbing the mm until we know the rest will
> succeed, as the fd isn't visible before then.
> 
> Signed-off-by: Jens Axboe <axboe@kernel.dk>
> ---
>  fs/userfaultfd.c | 42 ++++++++++++++++++++++++++----------------
>  1 file changed, 26 insertions(+), 16 deletions(-)
> 
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index 60dcfafdc11a..7864c2dba858 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -282,7 +282,7 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
>  /*
>   * Verify the pagetables are still not ok after having reigstered into
>   * the fault_pending_wqh to avoid userland having to UFFDIO_WAKE any
> - * userfault that has already been resolved, if userfaultfd_read and
> + * userfault that has already been resolved, if userfaultfd_read_iter and
>   * UFFDIO_COPY|ZEROPAGE are being run simultaneously on two different
>   * threads.
>   */
> @@ -1177,34 +1177,34 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
>  	return ret;
>  }
>  
> -static ssize_t userfaultfd_read(struct file *file, char __user *buf,
> -				size_t count, loff_t *ppos)
> +static ssize_t userfaultfd_read_iter(struct kiocb *iocb, struct iov_iter *to)
>  {
> +	struct file *file = iocb->ki_filp;
>  	struct userfaultfd_ctx *ctx = file->private_data;
>  	ssize_t _ret, ret = 0;
>  	struct uffd_msg msg;
> -	int no_wait = file->f_flags & O_NONBLOCK;
>  	struct inode *inode = file_inode(file);
> +	bool no_wait;
>  
>  	if (!userfaultfd_is_initialized(ctx))
>  		return -EINVAL;
>  
> +	no_wait = file->f_flags & O_NONBLOCK || iocb->ki_flags & IOCB_NOWAIT;
>  	for (;;) {
> -		if (count < sizeof(msg))
> +		if (iov_iter_count(to) < sizeof(msg))
>  			return ret ? ret : -EINVAL;
>  		_ret = userfaultfd_ctx_read(ctx, no_wait, &msg, inode);
>  		if (_ret < 0)
>  			return ret ? ret : _ret;
> -		if (copy_to_user((__u64 __user *) buf, &msg, sizeof(msg)))
> +		_ret = copy_to_iter(&msg, sizeof(msg), to);
> +		if (_ret < 0)
>  			return ret ? ret : -EFAULT;
>  		ret += sizeof(msg);
> -		buf += sizeof(msg);
> -		count -= sizeof(msg);
>  		/*
>  		 * Allow to read more than one fault at time but only
>  		 * block if waiting for the very first one.
>  		 */
> -		no_wait = O_NONBLOCK;
> +		no_wait = true;
>  	}
>  }
>  
> @@ -2172,7 +2172,7 @@ static const struct file_operations userfaultfd_fops = {
>  #endif
>  	.release	= userfaultfd_release,
>  	.poll		= userfaultfd_poll,
> -	.read		= userfaultfd_read,
> +	.read_iter	= userfaultfd_read_iter,
>  	.unlocked_ioctl = userfaultfd_ioctl,
>  	.compat_ioctl	= compat_ptr_ioctl,
>  	.llseek		= noop_llseek,
> @@ -2192,6 +2192,7 @@ static void init_once_userfaultfd_ctx(void *mem)
>  static int new_userfaultfd(int flags)
>  {
>  	struct userfaultfd_ctx *ctx;
> +	struct file *file;
>  	int fd;
>  
>  	BUG_ON(!current->mm);
> @@ -2215,16 +2216,25 @@ static int new_userfaultfd(int flags)
>  	init_rwsem(&ctx->map_changing_lock);
>  	atomic_set(&ctx->mmap_changing, 0);
>  	ctx->mm = current->mm;
> -	/* prevent the mm struct to be freed */
> -	mmgrab(ctx->mm);
> +
> +	fd = get_unused_fd_flags(O_RDONLY | (flags & UFFD_SHARED_FCNTL_FLAGS));
> +	if (fd < 0)
> +		goto err_out;
>  
>  	/* Create a new inode so that the LSM can block the creation.  */
> -	fd = anon_inode_create_getfd("[userfaultfd]", &userfaultfd_fops, ctx,
> +	file = anon_inode_create_getfile("[userfaultfd]", &userfaultfd_fops, ctx,
>  			O_RDONLY | (flags & UFFD_SHARED_FCNTL_FLAGS), NULL);
> -	if (fd < 0) {
> -		mmdrop(ctx->mm);
> -		kmem_cache_free(userfaultfd_ctx_cachep, ctx);
> +	if (IS_ERR(file)) {
> +		fd = PTR_ERR(file);
> +		goto err_out;

You're leaking the fd you allocated above.

>  	}
> +	/* prevent the mm struct to be freed */
> +	mmgrab(ctx->mm);
> +	file->f_mode |= FMODE_NOWAIT;
> +	fd_install(fd, file);
> +	return fd;
> +err_out:
> +	kmem_cache_free(userfaultfd_ctx_cachep, ctx);
>  	return fd;
>  }
>  
> -- 
> 2.43.0
>
Jens Axboe April 3, 2024, 1:44 p.m. UTC | #2
On 4/3/24 4:09 AM, Christian Brauner wrote:
>> @@ -2215,16 +2216,25 @@ static int new_userfaultfd(int flags)
>>  	init_rwsem(&ctx->map_changing_lock);
>>  	atomic_set(&ctx->mmap_changing, 0);
>>  	ctx->mm = current->mm;
>> -	/* prevent the mm struct to be freed */
>> -	mmgrab(ctx->mm);
>> +
>> +	fd = get_unused_fd_flags(O_RDONLY | (flags & UFFD_SHARED_FCNTL_FLAGS));
>> +	if (fd < 0)
>> +		goto err_out;
>>  
>>  	/* Create a new inode so that the LSM can block the creation.  */
>> -	fd = anon_inode_create_getfd("[userfaultfd]", &userfaultfd_fops, ctx,
>> +	file = anon_inode_create_getfile("[userfaultfd]", &userfaultfd_fops, ctx,
>>  			O_RDONLY | (flags & UFFD_SHARED_FCNTL_FLAGS), NULL);
>> -	if (fd < 0) {
>> -		mmdrop(ctx->mm);
>> -		kmem_cache_free(userfaultfd_ctx_cachep, ctx);
>> +	if (IS_ERR(file)) {
>> +		fd = PTR_ERR(file);
>> +		goto err_out;
> 
> You're leaking the fd you allocated above.

Oops yes - thanks, fixed.
diff mbox series

Patch

diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 60dcfafdc11a..7864c2dba858 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -282,7 +282,7 @@  static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
 /*
  * Verify the pagetables are still not ok after having reigstered into
  * the fault_pending_wqh to avoid userland having to UFFDIO_WAKE any
- * userfault that has already been resolved, if userfaultfd_read and
+ * userfault that has already been resolved, if userfaultfd_read_iter and
  * UFFDIO_COPY|ZEROPAGE are being run simultaneously on two different
  * threads.
  */
@@ -1177,34 +1177,34 @@  static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
 	return ret;
 }
 
-static ssize_t userfaultfd_read(struct file *file, char __user *buf,
-				size_t count, loff_t *ppos)
+static ssize_t userfaultfd_read_iter(struct kiocb *iocb, struct iov_iter *to)
 {
+	struct file *file = iocb->ki_filp;
 	struct userfaultfd_ctx *ctx = file->private_data;
 	ssize_t _ret, ret = 0;
 	struct uffd_msg msg;
-	int no_wait = file->f_flags & O_NONBLOCK;
 	struct inode *inode = file_inode(file);
+	bool no_wait;
 
 	if (!userfaultfd_is_initialized(ctx))
 		return -EINVAL;
 
+	no_wait = file->f_flags & O_NONBLOCK || iocb->ki_flags & IOCB_NOWAIT;
 	for (;;) {
-		if (count < sizeof(msg))
+		if (iov_iter_count(to) < sizeof(msg))
 			return ret ? ret : -EINVAL;
 		_ret = userfaultfd_ctx_read(ctx, no_wait, &msg, inode);
 		if (_ret < 0)
 			return ret ? ret : _ret;
-		if (copy_to_user((__u64 __user *) buf, &msg, sizeof(msg)))
+		_ret = copy_to_iter(&msg, sizeof(msg), to);
+		if (_ret < 0)
 			return ret ? ret : -EFAULT;
 		ret += sizeof(msg);
-		buf += sizeof(msg);
-		count -= sizeof(msg);
 		/*
 		 * Allow to read more than one fault at time but only
 		 * block if waiting for the very first one.
 		 */
-		no_wait = O_NONBLOCK;
+		no_wait = true;
 	}
 }
 
@@ -2172,7 +2172,7 @@  static const struct file_operations userfaultfd_fops = {
 #endif
 	.release	= userfaultfd_release,
 	.poll		= userfaultfd_poll,
-	.read		= userfaultfd_read,
+	.read_iter	= userfaultfd_read_iter,
 	.unlocked_ioctl = userfaultfd_ioctl,
 	.compat_ioctl	= compat_ptr_ioctl,
 	.llseek		= noop_llseek,
@@ -2192,6 +2192,7 @@  static void init_once_userfaultfd_ctx(void *mem)
 static int new_userfaultfd(int flags)
 {
 	struct userfaultfd_ctx *ctx;
+	struct file *file;
 	int fd;
 
 	BUG_ON(!current->mm);
@@ -2215,16 +2216,25 @@  static int new_userfaultfd(int flags)
 	init_rwsem(&ctx->map_changing_lock);
 	atomic_set(&ctx->mmap_changing, 0);
 	ctx->mm = current->mm;
-	/* prevent the mm struct to be freed */
-	mmgrab(ctx->mm);
+
+	fd = get_unused_fd_flags(O_RDONLY | (flags & UFFD_SHARED_FCNTL_FLAGS));
+	if (fd < 0)
+		goto err_out;
 
 	/* Create a new inode so that the LSM can block the creation.  */
-	fd = anon_inode_create_getfd("[userfaultfd]", &userfaultfd_fops, ctx,
+	file = anon_inode_create_getfile("[userfaultfd]", &userfaultfd_fops, ctx,
 			O_RDONLY | (flags & UFFD_SHARED_FCNTL_FLAGS), NULL);
-	if (fd < 0) {
-		mmdrop(ctx->mm);
-		kmem_cache_free(userfaultfd_ctx_cachep, ctx);
+	if (IS_ERR(file)) {
+		fd = PTR_ERR(file);
+		goto err_out;
 	}
+	/* prevent the mm struct to be freed */
+	mmgrab(ctx->mm);
+	file->f_mode |= FMODE_NOWAIT;
+	fd_install(fd, file);
+	return fd;
+err_out:
+	kmem_cache_free(userfaultfd_ctx_cachep, ctx);
 	return fd;
 }