diff mbox series

[v3,for-next,06/13] RDMA/core: Use refcount_t instead of atomic_t on refcount of mcast_group

Message ID 1621925504-33019-7-git-send-email-liweihang@huawei.com (mailing list archive)
State Superseded
Headers show
Series RDMA: Use refcount_t for reference counting | expand

Commit Message

Weihang Li May 25, 2021, 6:51 a.m. UTC
The refcount_t API will WARN on underflow and overflow of a reference
counter, and avoid use-after-free risks. Increase refcount_t from 0 to 1 is
regarded as there is a risk about use-after-free. So it should be set to 1
directly during initialization.

Signed-off-by: Weihang Li <liweihang@huawei.com>
---
 drivers/infiniband/core/multicast.c | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)
diff mbox series

Patch

diff --git a/drivers/infiniband/core/multicast.c b/drivers/infiniband/core/multicast.c
index a236532..17abc21 100644
--- a/drivers/infiniband/core/multicast.c
+++ b/drivers/infiniband/core/multicast.c
@@ -103,7 +103,7 @@  struct mcast_group {
 	struct list_head	active_list;
 	struct mcast_member	*last_join;
 	int			members[NUM_JOIN_MEMBERSHIP_TYPES];
-	atomic_t		refcount;
+	refcount_t		refcount;
 	enum mcast_group_state	state;
 	struct ib_sa_query	*query;
 	u16			pkey_index;
@@ -188,7 +188,7 @@  static void release_group(struct mcast_group *group)
 	unsigned long flags;
 
 	spin_lock_irqsave(&port->lock, flags);
-	if (atomic_dec_and_test(&group->refcount)) {
+	if (refcount_dec_and_test(&group->refcount)) {
 		rb_erase(&group->node, &port->table);
 		spin_unlock_irqrestore(&port->lock, flags);
 		kfree(group);
@@ -212,7 +212,7 @@  static void queue_join(struct mcast_member *member)
 	list_add_tail(&member->list, &group->pending_list);
 	if (group->state == MCAST_IDLE) {
 		group->state = MCAST_BUSY;
-		atomic_inc(&group->refcount);
+		refcount_inc(&group->refcount);
 		queue_work(mcast_wq, &group->work);
 	}
 	spin_unlock_irqrestore(&group->lock, flags);
@@ -565,8 +565,11 @@  static struct mcast_group *acquire_group(struct mcast_port *port,
 	if (!is_mgid0) {
 		spin_lock_irqsave(&port->lock, flags);
 		group = mcast_find(port, mgid);
-		if (group)
+		if (group) {
+			refcount_inc(&group->refcount);
 			goto found;
+		}
+
 		spin_unlock_irqrestore(&port->lock, flags);
 	}
 
@@ -590,8 +593,10 @@  static struct mcast_group *acquire_group(struct mcast_port *port,
 		group = cur_group;
 	} else
 		refcount_inc(&port->refcount);
+
+	refcount_set(&group->refcount, 1);
+
 found:
-	atomic_inc(&group->refcount);
 	spin_unlock_irqrestore(&port->lock, flags);
 	return group;
 }
@@ -780,7 +785,7 @@  static void mcast_groups_event(struct mcast_port *port,
 		group = rb_entry(node, struct mcast_group, node);
 		spin_lock(&group->lock);
 		if (group->state == MCAST_IDLE) {
-			atomic_inc(&group->refcount);
+			refcount_inc(&group->refcount);
 			queue_work(mcast_wq, &group->work);
 		}
 		if (group->state != MCAST_GROUP_ERROR)