
[1/3] io_uring: move cancelations to be io_uring_task based

Message ID 20241103175108.76460-2-axboe@kernel.dk (mailing list archive)
State New
Series Move io_kiocb from task_struct to io_uring_task

Commit Message

Jens Axboe Nov. 3, 2024, 5:49 p.m. UTC
Right now the task_struct pointer is used as the key to match a task,
but in preparation for the io_kiocb changes in this series, switch to
keying off struct io_uring_task instead. No functional changes are
intended in this patch.
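
At its core, the change swaps the match key from the owning task to
its io_uring context. As a minimal sketch (the helper names below are
hypothetical; the real logic lives in io_match_task_safe() and
io_match_task() in the diff):

	/* Before: filter a request by the owning task_struct.
	 * A NULL task means "don't filter by task".
	 */
	static bool req_owned_by_task(struct io_kiocb *head,
				      struct task_struct *task)
	{
		return !task || head->task == task;
	}

	/* After: filter by the per-task struct io_uring_task instead,
	 * reached through head->task->io_uring.
	 */
	static bool req_owned_by_tctx(struct io_kiocb *head,
				      struct io_uring_task *tctx)
	{
		return !tctx || head->task->io_uring == tctx;
	}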

Signed-off-by: Jens Axboe <axboe@kernel.dk>
---
 io_uring/futex.c     |  4 ++--
 io_uring/futex.h     |  4 ++--
 io_uring/io_uring.c  | 42 +++++++++++++++++++++---------------------
 io_uring/io_uring.h  |  2 +-
 io_uring/poll.c      |  4 ++--
 io_uring/poll.h      |  2 +-
 io_uring/timeout.c   |  8 ++++----
 io_uring/timeout.h   |  2 +-
 io_uring/uring_cmd.c |  4 ++--
 io_uring/uring_cmd.h |  2 +-
 io_uring/waitid.c    |  4 ++--
 io_uring/waitid.h    |  2 +-
 12 files changed, 40 insertions(+), 40 deletions(-)
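
Call sites change accordingly: cancelation paths now pass the task's
io_uring context rather than the task itself. From
io_uring_cancel_generic() in the patch:

	/* Before */
	loop |= io_uring_try_cancel_requests(ctx, current, cancel_all);

	/* After: key off the per-task io_uring context */
	loop |= io_uring_try_cancel_requests(ctx, current->io_uring,
					     cancel_all);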

Patch

diff --git a/io_uring/futex.c b/io_uring/futex.c
index 914848f46beb..e29662f039e1 100644
--- a/io_uring/futex.c
+++ b/io_uring/futex.c
@@ -141,7 +141,7 @@  int io_futex_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
 	return -ENOENT;
 }
 
-bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_futex_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			 bool cancel_all)
 {
 	struct hlist_node *tmp;
@@ -151,7 +151,7 @@  bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
 	lockdep_assert_held(&ctx->uring_lock);
 
 	hlist_for_each_entry_safe(req, tmp, &ctx->futex_list, hash_node) {
-		if (!io_match_task_safe(req, task, cancel_all))
+		if (!io_match_task_safe(req, tctx, cancel_all))
 			continue;
 		hlist_del_init(&req->hash_node);
 		__io_futex_cancel(ctx, req);
diff --git a/io_uring/futex.h b/io_uring/futex.h
index b8bb09873d57..d789fcf715e3 100644
--- a/io_uring/futex.h
+++ b/io_uring/futex.h
@@ -11,7 +11,7 @@  int io_futex_wake(struct io_kiocb *req, unsigned int issue_flags);
 #if defined(CONFIG_FUTEX)
 int io_futex_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
 		    unsigned int issue_flags);
-bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_futex_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			 bool cancel_all);
 bool io_futex_cache_init(struct io_ring_ctx *ctx);
 void io_futex_cache_free(struct io_ring_ctx *ctx);
@@ -23,7 +23,7 @@  static inline int io_futex_cancel(struct io_ring_ctx *ctx,
 	return 0;
 }
 static inline bool io_futex_remove_all(struct io_ring_ctx *ctx,
-				       struct task_struct *task, bool cancel_all)
+				       struct io_uring_task *tctx, bool cancel_all)
 {
 	return false;
 }
diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index 5b421e67c031..701cbd4670d8 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -143,7 +143,7 @@  struct io_defer_entry {
 #define IO_CQ_WAKE_FORCE	(IO_CQ_WAKE_INIT >> 1)
 
 static bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
-					 struct task_struct *task,
+					 struct io_uring_task *tctx,
 					 bool cancel_all);
 
 static void io_queue_sqe(struct io_kiocb *req);
@@ -202,12 +202,12 @@  static bool io_match_linked(struct io_kiocb *head)
  * As io_match_task() but protected against racing with linked timeouts.
  * User must not hold timeout_lock.
  */
-bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
+bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
 			bool cancel_all)
 {
 	bool matched;
 
-	if (task && head->task != task)
+	if (tctx && head->task->io_uring != tctx)
 		return false;
 	if (cancel_all)
 		return true;
@@ -3286,7 +3286,7 @@  static int io_uring_release(struct inode *inode, struct file *file)
 }
 
 struct io_task_cancel {
-	struct task_struct *task;
+	struct io_uring_task *tctx;
 	bool all;
 };
 
@@ -3295,11 +3295,11 @@  static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
 	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
 	struct io_task_cancel *cancel = data;
 
-	return io_match_task_safe(req, cancel->task, cancel->all);
+	return io_match_task_safe(req, cancel->tctx, cancel->all);
 }
 
 static __cold bool io_cancel_defer_files(struct io_ring_ctx *ctx,
-					 struct task_struct *task,
+					 struct io_uring_task *tctx,
 					 bool cancel_all)
 {
 	struct io_defer_entry *de;
@@ -3307,7 +3307,7 @@  static __cold bool io_cancel_defer_files(struct io_ring_ctx *ctx,
 
 	spin_lock(&ctx->completion_lock);
 	list_for_each_entry_reverse(de, &ctx->defer_list, list) {
-		if (io_match_task_safe(de->req, task, cancel_all)) {
+		if (io_match_task_safe(de->req, tctx, cancel_all)) {
 			list_cut_position(&list, &ctx->defer_list, &de->list);
 			break;
 		}
@@ -3350,11 +3350,10 @@  static __cold bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx)
 }
 
 static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
-						struct task_struct *task,
+						struct io_uring_task *tctx,
 						bool cancel_all)
 {
-	struct io_task_cancel cancel = { .task = task, .all = cancel_all, };
-	struct io_uring_task *tctx = task ? task->io_uring : NULL;
+	struct io_task_cancel cancel = { .tctx = tctx, .all = cancel_all, };
 	enum io_wq_cancel cret;
 	bool ret = false;
 
@@ -3368,9 +3367,9 @@  static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 	if (!ctx->rings)
 		return false;
 
-	if (!task) {
+	if (!tctx) {
 		ret |= io_uring_try_cancel_iowq(ctx);
-	} else if (tctx && tctx->io_wq) {
+	} else if (tctx->io_wq) {
 		/*
 		 * Cancels requests of all rings, not only @ctx, but
 		 * it's fine as the task is in exit/exec.
@@ -3393,15 +3392,15 @@  static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 	if ((ctx->flags & IORING_SETUP_DEFER_TASKRUN) &&
 	    io_allowed_defer_tw_run(ctx))
 		ret |= io_run_local_work(ctx, INT_MAX) > 0;
-	ret |= io_cancel_defer_files(ctx, task, cancel_all);
+	ret |= io_cancel_defer_files(ctx, tctx, cancel_all);
 	mutex_lock(&ctx->uring_lock);
-	ret |= io_poll_remove_all(ctx, task, cancel_all);
-	ret |= io_waitid_remove_all(ctx, task, cancel_all);
-	ret |= io_futex_remove_all(ctx, task, cancel_all);
-	ret |= io_uring_try_cancel_uring_cmd(ctx, task, cancel_all);
+	ret |= io_poll_remove_all(ctx, tctx, cancel_all);
+	ret |= io_waitid_remove_all(ctx, tctx, cancel_all);
+	ret |= io_futex_remove_all(ctx, tctx, cancel_all);
+	ret |= io_uring_try_cancel_uring_cmd(ctx, tctx, cancel_all);
 	mutex_unlock(&ctx->uring_lock);
-	ret |= io_kill_timeouts(ctx, task, cancel_all);
-	if (task)
+	ret |= io_kill_timeouts(ctx, tctx, cancel_all);
+	if (tctx)
 		ret |= io_run_task_work() > 0;
 	else
 		ret |= flush_delayed_work(&ctx->fallback_work);
@@ -3454,12 +3453,13 @@  __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
 				if (node->ctx->sq_data)
 					continue;
 				loop |= io_uring_try_cancel_requests(node->ctx,
-							current, cancel_all);
+							current->io_uring,
+							cancel_all);
 			}
 		} else {
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 				loop |= io_uring_try_cancel_requests(ctx,
-								     current,
+								     current->io_uring,
 								     cancel_all);
 		}
 
diff --git a/io_uring/io_uring.h b/io_uring/io_uring.h
index 52d15ac8d209..14d73a727320 100644
--- a/io_uring/io_uring.h
+++ b/io_uring/io_uring.h
@@ -116,7 +116,7 @@  void io_queue_next(struct io_kiocb *req);
 void io_task_refs_refill(struct io_uring_task *tctx);
 bool __io_alloc_req_refill(struct io_ring_ctx *ctx);
 
-bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
+bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
 			bool cancel_all);
 
 void io_activate_pollwq(struct io_ring_ctx *ctx);
diff --git a/io_uring/poll.c b/io_uring/poll.c
index 2d6698fb7400..7db3010b5733 100644
--- a/io_uring/poll.c
+++ b/io_uring/poll.c
@@ -714,7 +714,7 @@  int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
 /*
  * Returns true if we found and killed one or more poll requests
  */
-__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
+__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			       bool cancel_all)
 {
 	unsigned nr_buckets = 1U << ctx->cancel_table.hash_bits;
@@ -729,7 +729,7 @@  __cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
 		struct io_hash_bucket *hb = &ctx->cancel_table.hbs[i];
 
 		hlist_for_each_entry_safe(req, tmp, &hb->list, hash_node) {
-			if (io_match_task_safe(req, tsk, cancel_all)) {
+			if (io_match_task_safe(req, tctx, cancel_all)) {
 				hlist_del_init(&req->hash_node);
 				io_poll_cancel_req(req);
 				found = true;
diff --git a/io_uring/poll.h b/io_uring/poll.h
index b0e3745f5a29..04ede93113dc 100644
--- a/io_uring/poll.h
+++ b/io_uring/poll.h
@@ -40,7 +40,7 @@  struct io_cancel_data;
 int io_poll_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
 		   unsigned issue_flags);
 int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags);
-bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
+bool io_poll_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			bool cancel_all);
 
 void io_poll_task_func(struct io_kiocb *req, struct io_tw_state *ts);
diff --git a/io_uring/timeout.c b/io_uring/timeout.c
index ed6c74f1a475..31fbea366d43 100644
--- a/io_uring/timeout.c
+++ b/io_uring/timeout.c
@@ -643,13 +643,13 @@  void io_queue_linked_timeout(struct io_kiocb *req)
 	io_put_req(req);
 }
 
-static bool io_match_task(struct io_kiocb *head, struct task_struct *task,
+static bool io_match_task(struct io_kiocb *head, struct io_uring_task *tctx,
 			  bool cancel_all)
 	__must_hold(&head->ctx->timeout_lock)
 {
 	struct io_kiocb *req;
 
-	if (task && head->task != task)
+	if (tctx && head->task->io_uring != tctx)
 		return false;
 	if (cancel_all)
 		return true;
@@ -662,7 +662,7 @@  static bool io_match_task(struct io_kiocb *head, struct task_struct *task,
 }
 
 /* Returns true if we found and killed one or more timeouts */
-__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
+__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			     bool cancel_all)
 {
 	struct io_timeout *timeout, *tmp;
@@ -677,7 +677,7 @@  __cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
 	list_for_each_entry_safe(timeout, tmp, &ctx->timeout_list, list) {
 		struct io_kiocb *req = cmd_to_io_kiocb(timeout);
 
-		if (io_match_task(req, tsk, cancel_all) &&
+		if (io_match_task(req, tctx, cancel_all) &&
 		    io_kill_timeout(req, -ECANCELED))
 			canceled++;
 	}
diff --git a/io_uring/timeout.h b/io_uring/timeout.h
index a6939f18313e..e91b32448dcf 100644
--- a/io_uring/timeout.h
+++ b/io_uring/timeout.h
@@ -24,7 +24,7 @@  static inline struct io_kiocb *io_disarm_linked_timeout(struct io_kiocb *req)
 __cold void io_flush_timeouts(struct io_ring_ctx *ctx);
 struct io_cancel_data;
 int io_timeout_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd);
-__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
+__cold bool io_kill_timeouts(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			     bool cancel_all);
 void io_queue_linked_timeout(struct io_kiocb *req);
 void io_disarm_next(struct io_kiocb *req);
diff --git a/io_uring/uring_cmd.c b/io_uring/uring_cmd.c
index 88a73d21fc0b..f88fbc9869d0 100644
--- a/io_uring/uring_cmd.c
+++ b/io_uring/uring_cmd.c
@@ -47,7 +47,7 @@  static void io_req_uring_cleanup(struct io_kiocb *req, unsigned int issue_flags)
 }
 
 bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx,
-				   struct task_struct *task, bool cancel_all)
+				   struct io_uring_task *tctx, bool cancel_all)
 {
 	struct hlist_node *tmp;
 	struct io_kiocb *req;
@@ -61,7 +61,7 @@  bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx,
 				struct io_uring_cmd);
 		struct file *file = req->file;
 
-		if (!cancel_all && req->task != task)
+		if (!cancel_all && req->task->io_uring != tctx)
 			continue;
 
 		if (cmd->flags & IORING_URING_CMD_CANCELABLE) {
diff --git a/io_uring/uring_cmd.h b/io_uring/uring_cmd.h
index a361f98664d2..7dba0f1efc58 100644
--- a/io_uring/uring_cmd.h
+++ b/io_uring/uring_cmd.h
@@ -8,4 +8,4 @@  int io_uring_cmd(struct io_kiocb *req, unsigned int issue_flags);
 int io_uring_cmd_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe);
 
 bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx,
-				   struct task_struct *task, bool cancel_all);
+				   struct io_uring_task *tctx, bool cancel_all);
diff --git a/io_uring/waitid.c b/io_uring/waitid.c
index 6362ec20abc0..9b7c23f96c47 100644
--- a/io_uring/waitid.c
+++ b/io_uring/waitid.c
@@ -184,7 +184,7 @@  int io_waitid_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
 	return -ENOENT;
 }
 
-bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			  bool cancel_all)
 {
 	struct hlist_node *tmp;
@@ -194,7 +194,7 @@  bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
 	lockdep_assert_held(&ctx->uring_lock);
 
 	hlist_for_each_entry_safe(req, tmp, &ctx->waitid_list, hash_node) {
-		if (!io_match_task_safe(req, task, cancel_all))
+		if (!io_match_task_safe(req, tctx, cancel_all))
 			continue;
 		hlist_del_init(&req->hash_node);
 		__io_waitid_cancel(ctx, req);
diff --git a/io_uring/waitid.h b/io_uring/waitid.h
index 956a8adafe8c..d5544aaf302a 100644
--- a/io_uring/waitid.h
+++ b/io_uring/waitid.h
@@ -11,5 +11,5 @@  int io_waitid_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe);
 int io_waitid(struct io_kiocb *req, unsigned int issue_flags);
 int io_waitid_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
 		     unsigned int issue_flags);
-bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
+bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
 			  bool cancel_all);