diff mbox series

[net-next,v4,01/12] net: devlink: make sure that devlink_try_get() works with valid pointer during xarray iteration

Message ID 20220725082925.366455-2-jiri@resnulli.us (mailing list archive)
State Accepted
Commit 30bab7cdb56da4819ff081ad658646f2df16c098
Delegated to: Netdev Maintainers
Headers show
Series Implement dev info and dev flash for line cards | expand

Checks

Context Check Description
netdev/tree_selection success Clearly marked for net-next
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix success Link
netdev/cover_letter success Series has a cover letter
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 2 this patch: 2
netdev/cc_maintainers success CCed 6 of 6 maintainers
netdev/build_clang success Errors and warnings before: 5 this patch: 5
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 2 this patch: 2
netdev/checkpatch warning WARNING: line length of 81 exceeds 80 columns
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Jiri Pirko July 25, 2022, 8:29 a.m. UTC
From: Jiri Pirko <jiri@nvidia.com>

Remove dependency on devlink_mutex during devlinks xarray iteration.

The reason is that devlink_register/unregister() functions taking
devlink_mutex would deadlock during devlink reload operation of devlink
instance which registers/unregisters nested devlink instances.

The devlinks xarray consistency is ensured internally by xarray.
There is a reference taken when working with devlink using
devlink_try_get(). But there is no guarantee that devlink pointer
picked during xarray iteration is not freed before devlink_try_get()
is called.

Make sure that devlink_try_get() works with valid pointer.
Achieve it by:
1) Splitting devlink_put() so the completion is sent only
   after grace period. Completion unblocks the devlink_unregister()
   routine, which is followed-up by devlink_free()
2) During devlinks xa_array iteration, get devlink pointer from xa_array
   holding RCU read lock and taking reference using devlink_try_get()
   before unlock.

Signed-off-by: Jiri Pirko <jiri@nvidia.com>
---
v3->v4:
- introduced an iteration helpers and convert to use them instead of
  manually locking rcu_read_lock over xa_for_each_marked()
  and devlink_try_get() couple
- converted devlink_get_from_attrs() to take reference during iteration
  as well.
v2->v3:
- s/enf/end/ in devlink_put() comment
- added missing rcu_read_lock() call to info_get_dumpit()
- extended patch description by motivation
- removed an extra "by" from patch description
v1->v2:
- new patch (originally part of different patchset)
---
 net/core/devlink.c | 171 +++++++++++++++++++++------------------------
 1 file changed, 80 insertions(+), 91 deletions(-)

Comments

Jakub Kicinski July 26, 2022, 1:47 a.m. UTC | #1
On Mon, 25 Jul 2022 10:29:14 +0200 Jiri Pirko wrote:
> From: Jiri Pirko <jiri@nvidia.com>
> 
> Remove dependency on devlink_mutex during devlinks xarray iteration.
> 
> The reason is that devlink_register/unregister() functions taking
> devlink_mutex would deadlock during devlink reload operation of devlink
> instance which registers/unregisters nested devlink instances.
> 
> The devlinks xarray consistency is ensured internally by xarray.
> There is a reference taken when working with devlink using
> devlink_try_get(). But there is no guarantee that devlink pointer
> picked during xarray iteration is not freed before devlink_try_get()
> is called.
> 
> Make sure that devlink_try_get() works with valid pointer.
> Achieve it by:
> 1) Splitting devlink_put() so the completion is sent only
>    after grace period. Completion unblocks the devlink_unregister()
>    routine, which is followed-up by devlink_free()
> 2) During devlinks xa_array iteration, get devlink pointer from xa_array
>    holding RCU read lock and taking reference using devlink_try_get()
>    before unlock.
> 
> Signed-off-by: Jiri Pirko <jiri@nvidia.com>

Reviewed-by: Jakub Kicinski <kuba@kernel.org>
diff mbox series

Patch

diff --git a/net/core/devlink.c b/net/core/devlink.c
index 98d79feeb3dc..c7abd928f389 100644
--- a/net/core/devlink.c
+++ b/net/core/devlink.c
@@ -70,6 +70,7 @@  struct devlink {
 	u8 reload_failed:1;
 	refcount_t refcount;
 	struct completion comp;
+	struct rcu_head rcu;
 	char priv[] __aligned(NETDEV_ALIGN);
 };
 
@@ -221,8 +222,6 @@  static DEFINE_XARRAY_FLAGS(devlinks, XA_FLAGS_ALLOC);
 /* devlink_mutex
  *
  * An overall lock guarding every operation coming from userspace.
- * It also guards devlink devices list and it is taken when
- * driver registers/unregisters it.
  */
 static DEFINE_MUTEX(devlink_mutex);
 
@@ -232,10 +231,21 @@  struct net *devlink_net(const struct devlink *devlink)
 }
 EXPORT_SYMBOL_GPL(devlink_net);
 
+static void __devlink_put_rcu(struct rcu_head *head)
+{
+	struct devlink *devlink = container_of(head, struct devlink, rcu);
+
+	complete(&devlink->comp);
+}
+
 void devlink_put(struct devlink *devlink)
 {
 	if (refcount_dec_and_test(&devlink->refcount))
-		complete(&devlink->comp);
+		/* Make sure unregister operation that may await the completion
+		 * is unblocked only after all users are after the end of
+		 * RCU grace period.
+		 */
+		call_rcu(&devlink->rcu, __devlink_put_rcu);
 }
 
 struct devlink *__must_check devlink_try_get(struct devlink *devlink)
@@ -278,12 +288,55 @@  void devl_unlock(struct devlink *devlink)
 }
 EXPORT_SYMBOL_GPL(devl_unlock);
 
+static struct devlink *
+devlinks_xa_find_get(unsigned long *indexp, xa_mark_t filter,
+		     void * (*xa_find_fn)(struct xarray *, unsigned long *,
+					  unsigned long, xa_mark_t))
+{
+	struct devlink *devlink;
+
+	rcu_read_lock();
+retry:
+	devlink = xa_find_fn(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED);
+	if (!devlink)
+		goto unlock;
+	/* For a possible retry, the xa_find_after() should be always used */
+	xa_find_fn = xa_find_after;
+	if (!devlink_try_get(devlink))
+		goto retry;
+unlock:
+	rcu_read_unlock();
+	return devlink;
+}
+
+static struct devlink *devlinks_xa_find_get_first(unsigned long *indexp,
+						  xa_mark_t filter)
+{
+	return devlinks_xa_find_get(indexp, filter, xa_find);
+}
+
+static struct devlink *devlinks_xa_find_get_next(unsigned long *indexp,
+						 xa_mark_t filter)
+{
+	return devlinks_xa_find_get(indexp, filter, xa_find_after);
+}
+
+/* Iterate over devlink pointers which were possible to get reference to.
+ * devlink_put() needs to be called for each iterated devlink pointer
+ * in loop body in order to release the reference.
+ */
+#define devlinks_xa_for_each_get(index, devlink, filter)			\
+	for (index = 0, devlink = devlinks_xa_find_get_first(&index, filter);	\
+	     devlink; devlink = devlinks_xa_find_get_next(&index, filter))
+
+#define devlinks_xa_for_each_registered_get(index, devlink)			\
+	devlinks_xa_for_each_get(index, devlink, DEVLINK_REGISTERED)
+
 static struct devlink *devlink_get_from_attrs(struct net *net,
 					      struct nlattr **attrs)
 {
 	struct devlink *devlink;
 	unsigned long index;
-	bool found = false;
 	char *busname;
 	char *devname;
 
@@ -293,21 +346,15 @@  static struct devlink *devlink_get_from_attrs(struct net *net,
 	busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]);
 	devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
 
-	lockdep_assert_held(&devlink_mutex);
-
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (strcmp(devlink->dev->bus->name, busname) == 0 &&
 		    strcmp(dev_name(devlink->dev), devname) == 0 &&
-		    net_eq(devlink_net(devlink), net)) {
-			found = true;
-			break;
-		}
+		    net_eq(devlink_net(devlink), net))
+			return devlink;
+		devlink_put(devlink);
 	}
 
-	if (!found || !devlink_try_get(devlink))
-		devlink = ERR_PTR(-ENODEV);
-
-	return devlink;
+	return ERR_PTR(-ENODEV);
 }
 
 static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink,
@@ -1329,10 +1376,7 @@  static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -1432,10 +1476,7 @@  static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
 			devlink_put(devlink);
 			continue;
@@ -1495,10 +1536,7 @@  static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -2177,10 +2215,7 @@  static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -2449,10 +2484,7 @@  static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -2601,10 +2633,7 @@  static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_pool_get)
 			goto retry;
@@ -2822,10 +2851,7 @@  static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_port_pool_get)
 			goto retry;
@@ -3071,10 +3097,7 @@  devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_tc_pool_bind_get)
 			goto retry;
@@ -5158,10 +5181,7 @@  static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -5393,10 +5413,7 @@  static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -5977,10 +5994,7 @@  static int devlink_nl_cmd_region_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -6511,10 +6525,7 @@  static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg,
 	int err = 0;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -7691,10 +7702,7 @@  devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry_rep;
 
@@ -7721,10 +7729,7 @@  devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg,
 		devlink_put(devlink);
 	}
 
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry_port;
 
@@ -8291,10 +8296,7 @@  static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -8518,10 +8520,7 @@  static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -8832,10 +8831,7 @@  static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg,
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			goto retry;
 
@@ -9589,10 +9585,8 @@  void devlink_register(struct devlink *devlink)
 	ASSERT_DEVLINK_NOT_REGISTERED(devlink);
 	/* Make sure that we are in .probe() routine */
 
-	mutex_lock(&devlink_mutex);
 	xa_set_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
 	devlink_notify_register(devlink);
-	mutex_unlock(&devlink_mutex);
 }
 EXPORT_SYMBOL_GPL(devlink_register);
 
@@ -9609,10 +9603,8 @@  void devlink_unregister(struct devlink *devlink)
 	devlink_put(devlink);
 	wait_for_completion(&devlink->comp);
 
-	mutex_lock(&devlink_mutex);
 	devlink_notify_unregister(devlink);
 	xa_clear_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
-	mutex_unlock(&devlink_mutex);
 }
 EXPORT_SYMBOL_GPL(devlink_unregister);
 
@@ -12281,10 +12273,7 @@  static void __net_exit devlink_pernet_pre_exit(struct net *net)
 	 * all devlink instances from this namespace into init_net.
 	 */
 	mutex_lock(&devlink_mutex);
-	xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
-		if (!devlink_try_get(devlink))
-			continue;
-
+	devlinks_xa_for_each_registered_get(index, devlink) {
 		if (!net_eq(devlink_net(devlink), net))
 			goto retry;