[v3,03/10] fs: lockless mntns rbtree lookup

Message ID 20241213-work-mount-rbtree-lockless-v3-3-6e3cdaf9b280@kernel.org
State New
Series fs: lockless mntns lookup

Commit Message

Christian Brauner Dec. 12, 2024, 11:03 p.m. UTC
Currently we use a read-write lock but for the simple search case we can
make this lockless. Creating a new mount namespace is a rather rare
event compared with querying mounts in a foreign mount namespace. Once
this is picked up by, e.g., systemd to list mounts in another mount
namespace in its isolated services or in containers, this will be used
a lot, so it seems worthwhile doing.

Signed-off-by: Christian Brauner <brauner@kernel.org>
---
 fs/mount.h     |   5 ++-
 fs/namespace.c | 119 +++++++++++++++++++++++++++++++++++----------------------
 2 files changed, 77 insertions(+), 47 deletions(-)
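
The lookup made lockless here is the one behind mnt_ns_id-based
statmount()/listmount() callers such as the systemd use case in the commit
message. For context, a rough userspace sketch of such a caller follows; the
struct layout, the __NR_listmount value, and the list_mounts_of_ns() helper
are assumptions for illustration only — real code should take the definitions
from <linux/mount.h> and the architecture's unistd.h.

/*
 * Sketch only: list the mount ids of a foreign mount namespace by its
 * 64-bit mnt_ns_id. Struct layout and __NR_listmount are assumptions
 * (Linux 6.10+ uapi); prefer <linux/mount.h> in real code.
 */
#include <stdint.h>
#include <string.h>
#include <unistd.h>
#include <sys/syscall.h>

#ifndef __NR_listmount
#define __NR_listmount 458		/* x86_64/asm-generic value, assumption */
#endif

struct mnt_id_req {			/* mirrors the uapi struct, VER1 layout */
	uint32_t size;
	uint32_t spare;
	uint64_t mnt_id;		/* LSMT_ROOT: start from the ns root */
	uint64_t param;			/* last seen mount id when iterating */
	uint64_t mnt_ns_id;		/* which mount namespace to look up */
};

#define LSMT_ROOT	0xffffffffffffffffULL

/* Returns the number of mount ids written to @ids, or -1 on error. */
static ssize_t list_mounts_of_ns(uint64_t mnt_ns_id, uint64_t *ids, size_t nr)
{
	struct mnt_id_req req;

	memset(&req, 0, sizeof(req));
	req.size = sizeof(req);
	req.mnt_id = LSMT_ROOT;
	req.mnt_ns_id = mnt_ns_id;	/* this is what lookup_mnt_ns() resolves */

	return syscall(__NR_listmount, &req, ids, nr, 0);
}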

Comments

Peter Zijlstra Dec. 13, 2024, 8:50 a.m. UTC | #1
On Fri, Dec 13, 2024 at 12:03:42AM +0100, Christian Brauner wrote:
> +static inline void mnt_ns_tree_write_lock(void)
> +{
> +	write_lock(&mnt_ns_tree_lock);
> +	write_seqcount_begin(&mnt_ns_tree_seqcount);
> +}
> +
> +static inline void mnt_ns_tree_write_unlock(void)
> +{
> +	write_seqcount_end(&mnt_ns_tree_seqcount);
> +	write_unlock(&mnt_ns_tree_lock);
>  }

>  static void mnt_ns_tree_add(struct mnt_namespace *ns)
>  {
> -	guard(write_lock)(&mnt_ns_tree_lock);
> -	rb_add(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_less);
> +	struct rb_node *node;
> +
> +	mnt_ns_tree_write_lock();
> +	node = rb_find_add_rcu(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_cmp);
> +	mnt_ns_tree_write_unlock();
> +
> +	WARN_ON_ONCE(node);
>  }

>  static void mnt_ns_tree_remove(struct mnt_namespace *ns)
>  {
>  	/* remove from global mount namespace list */
>  	if (!is_anon_ns(ns)) {
> -		guard(write_lock)(&mnt_ns_tree_lock);
> +		mnt_ns_tree_write_lock();
>  		rb_erase(&ns->mnt_ns_tree_node, &mnt_ns_tree);
> +		mnt_ns_tree_write_unlock();
>  	}
>  
> -	mnt_ns_release(ns);
> +	call_rcu(&ns->mnt_ns_rcu, mnt_ns_release_rcu);
>  }

It's probably not worth the effort, but I figured I'd mention it anyway,
if you do:

DEFINE_LOCK_GUARD_0(mnt_ns_tree_lock, mnt_ns_tree_lock(), mnt_ns_tree_unlock())

You can use: guard(mnt_ns_tree_lock)();

But like I said, probably not worth it, given the above are the only two
sites and they're utterly trivial.
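
For completeness, a minimal sketch of what that would look like against this
patch (illustrative only, not something the series adds; it reuses the
write-side helpers the patch introduces):

/*
 * Sketch of the guard-based variant, not part of the series: wrap the
 * paired rwlock/seqcount helpers so the two write-side call sites
 * become scope-based.
 */
DEFINE_LOCK_GUARD_0(mnt_ns_tree_lock,
		    mnt_ns_tree_write_lock(),
		    mnt_ns_tree_write_unlock())

static void mnt_ns_tree_add(struct mnt_namespace *ns)
{
	struct rb_node *node;

	guard(mnt_ns_tree_lock)();
	node = rb_find_add_rcu(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_cmp);
	/* the WARN now runs under the lock, which is harmless here */
	WARN_ON_ONCE(node);
}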


Anyway, rest of the patches look good now.
Jeff Layton Dec. 13, 2024, 2:11 p.m. UTC | #2
On Fri, 2024-12-13 at 00:03 +0100, Christian Brauner wrote:
> Currently we use a read-write lock but for the simple search case we can
> make this lockless. Creating a new mount namespace is a rather rare
> event compared with querying mounts in a foreign mount namespace. Once
> this is picked up by e.g., systemd to list mounts in another mount in
> it's isolated services or in containers this will be used a lot so this
> seems worthwhile doing.
> 
> Signed-off-by: Christian Brauner <brauner@kernel.org>
> ---
>  fs/mount.h     |   5 ++-
>  fs/namespace.c | 119 +++++++++++++++++++++++++++++++++++----------------------
>  2 files changed, 77 insertions(+), 47 deletions(-)
> 
> diff --git a/fs/mount.h b/fs/mount.h
> index 185fc56afc13338f8185fe818051444d540cbd5b..36ead0e45e8aa7614c00001102563a711d9dae6e 100644
> --- a/fs/mount.h
> +++ b/fs/mount.h
> @@ -12,7 +12,10 @@ struct mnt_namespace {
>  	struct user_namespace	*user_ns;
>  	struct ucounts		*ucounts;
>  	u64			seq;	/* Sequence number to prevent loops */
> -	wait_queue_head_t poll;
> +	union {
> +		wait_queue_head_t	poll;
> +		struct rcu_head		mnt_ns_rcu;
> +	};
>  	u64 event;
>  	unsigned int		nr_mounts; /* # of mounts in the namespace */
>  	unsigned int		pending_mounts;
> diff --git a/fs/namespace.c b/fs/namespace.c
> index 10fa18dd66018fadfdc9d18c59a851eed7bd55ad..52adee787eb1b6ee8831705b2b121854c3370fb3 100644
> --- a/fs/namespace.c
> +++ b/fs/namespace.c
> @@ -79,6 +79,8 @@ static DECLARE_RWSEM(namespace_sem);
>  static HLIST_HEAD(unmounted);	/* protected by namespace_sem */
>  static LIST_HEAD(ex_mountpoints); /* protected by namespace_sem */
>  static DEFINE_RWLOCK(mnt_ns_tree_lock);
> +static seqcount_rwlock_t mnt_ns_tree_seqcount = SEQCNT_RWLOCK_ZERO(mnt_ns_tree_seqcount, &mnt_ns_tree_lock);
> +
>  static struct rb_root mnt_ns_tree = RB_ROOT; /* protected by mnt_ns_tree_lock */
>  
>  struct mount_kattr {
> @@ -105,17 +107,6 @@ EXPORT_SYMBOL_GPL(fs_kobj);
>   */
>  __cacheline_aligned_in_smp DEFINE_SEQLOCK(mount_lock);
>  
> -static int mnt_ns_cmp(u64 seq, const struct mnt_namespace *ns)
> -{
> -	u64 seq_b = ns->seq;
> -
> -	if (seq < seq_b)
> -		return -1;
> -	if (seq > seq_b)
> -		return 1;
> -	return 0;
> -}
> -
>  static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
>  {
>  	if (!node)
> @@ -123,19 +114,41 @@ static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
>  	return rb_entry(node, struct mnt_namespace, mnt_ns_tree_node);
>  }
>  
> -static bool mnt_ns_less(struct rb_node *a, const struct rb_node *b)
> +static int mnt_ns_cmp(struct rb_node *a, const struct rb_node *b)
>  {
>  	struct mnt_namespace *ns_a = node_to_mnt_ns(a);
>  	struct mnt_namespace *ns_b = node_to_mnt_ns(b);
>  	u64 seq_a = ns_a->seq;
> +	u64 seq_b = ns_b->seq;
> +
> +	if (seq_a < seq_b)
> +		return -1;
> +	if (seq_a > seq_b)
> +		return 1;
> +	return 0;
> +}
>  
> -	return mnt_ns_cmp(seq_a, ns_b) < 0;
> +static inline void mnt_ns_tree_write_lock(void)
> +{
> +	write_lock(&mnt_ns_tree_lock);
> +	write_seqcount_begin(&mnt_ns_tree_seqcount);
> +}
> +
> +static inline void mnt_ns_tree_write_unlock(void)
> +{
> +	write_seqcount_end(&mnt_ns_tree_seqcount);
> +	write_unlock(&mnt_ns_tree_lock);
>  }
>  
>  static void mnt_ns_tree_add(struct mnt_namespace *ns)
>  {
> -	guard(write_lock)(&mnt_ns_tree_lock);
> -	rb_add(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_less);
> +	struct rb_node *node;
> +
> +	mnt_ns_tree_write_lock();
> +	node = rb_find_add_rcu(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_cmp);
> +	mnt_ns_tree_write_unlock();
> +
> +	WARN_ON_ONCE(node);
>  }
>  
>  static void mnt_ns_release(struct mnt_namespace *ns)
> @@ -150,41 +163,36 @@ static void mnt_ns_release(struct mnt_namespace *ns)
>  }
>  DEFINE_FREE(mnt_ns_release, struct mnt_namespace *, if (_T) mnt_ns_release(_T))
>  
> +static void mnt_ns_release_rcu(struct rcu_head *rcu)
> +{
> +	struct mnt_namespace *mnt_ns;
> +
> +	mnt_ns = container_of(rcu, struct mnt_namespace, mnt_ns_rcu);
> +	mnt_ns_release(mnt_ns);
> +}
> +
>  static void mnt_ns_tree_remove(struct mnt_namespace *ns)
>  {
>  	/* remove from global mount namespace list */
>  	if (!is_anon_ns(ns)) {
> -		guard(write_lock)(&mnt_ns_tree_lock);
> +		mnt_ns_tree_write_lock();
>  		rb_erase(&ns->mnt_ns_tree_node, &mnt_ns_tree);
> +		mnt_ns_tree_write_unlock();
>  	}
>  
> -	mnt_ns_release(ns);
> +	call_rcu(&ns->mnt_ns_rcu, mnt_ns_release_rcu);
>  }
>  
> -/*
> - * Returns the mount namespace which either has the specified id, or has the
> - * next smallest id afer the specified one.
> - */
> -static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
> +static int mnt_ns_find(const void *key, const struct rb_node *node)
>  {
> -	struct rb_node *node = mnt_ns_tree.rb_node;
> -	struct mnt_namespace *ret = NULL;
> -
> -	lockdep_assert_held(&mnt_ns_tree_lock);
> -
> -	while (node) {
> -		struct mnt_namespace *n = node_to_mnt_ns(node);
> +	const u64 mnt_ns_id = *(u64 *)key;
> +	const struct mnt_namespace *ns = node_to_mnt_ns(node);
>  
> -		if (mnt_ns_id <= n->seq) {
> -			ret = node_to_mnt_ns(node);
> -			if (mnt_ns_id == n->seq)
> -				break;
> -			node = node->rb_left;
> -		} else {
> -			node = node->rb_right;
> -		}
> -	}
> -	return ret;
> +	if (mnt_ns_id < ns->seq)
> +		return -1;
> +	if (mnt_ns_id > ns->seq)
> +		return 1;
> +	return 0;
>  }
>  
>  /*
> @@ -194,18 +202,37 @@ static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
>   * namespace the @namespace_sem must first be acquired. If the namespace has
>   * already shut down before acquiring @namespace_sem, {list,stat}mount() will
>   * see that the mount rbtree of the namespace is empty.
> + *
> + * Note the lookup is lockless protected by a sequence counter. We only
> + * need to guard against false negatives as false positives aren't
> + * possible. So if we didn't find a mount namespace and the sequence
> + * counter has changed we need to retry. If the sequence counter is
> + * still the same we know the search actually failed.
>   */
>  static struct mnt_namespace *lookup_mnt_ns(u64 mnt_ns_id)
>  {
> -       struct mnt_namespace *ns;
> +	struct mnt_namespace *ns;
> +	struct rb_node *node;
> +	unsigned int seq;
> +
> +	guard(rcu)();
> +	do {
> +		seq = read_seqcount_begin(&mnt_ns_tree_seqcount);
> +		node = rb_find_rcu(&mnt_ns_id, &mnt_ns_tree, mnt_ns_find);
> +		if (node)
> +			break;
> +	} while (read_seqcount_retry(&mnt_ns_tree_seqcount, seq));
>  
> -       guard(read_lock)(&mnt_ns_tree_lock);
> -       ns = mnt_ns_find_id_at(mnt_ns_id);
> -       if (!ns || ns->seq != mnt_ns_id)
> -               return NULL;
> +	if (!node)
> +		return NULL;
>  
> -       refcount_inc(&ns->passive);
> -       return ns;
> +	/*
> +	 * The last reference count is put with RCU delay so we can
> +	 * unconditonally acquire a reference here.
> +	 */
> +	ns = node_to_mnt_ns(node);
> +	refcount_inc(&ns->passive);

I'm a little uneasy with the unconditional refcount_inc() here. It
seems quite possible that this could do a 0->1 transition here. You may
be right that that technically won't cause a problem with the rcu lock
held, but at the very least, that will cause a refcount warning to pop.

Maybe this should be a refcount_inc_not_zero() and then you just return
NULL if the increment doesn't occur?
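
In other words, something like the following sketch (illustrative only;
Christian explains below why the patch keeps the unconditional increment):

	/*
	 * Sketch of the suggested alternative: treat a failed 0 -> 1
	 * increment as "namespace already going away" instead of
	 * relying on the RCU-delayed final put.
	 */
	ns = node_to_mnt_ns(node);
	if (!refcount_inc_not_zero(&ns->passive))
		return NULL;
	return ns;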

> +	return ns;
>  }
>  
>  static inline void lock_mount_hash(void)
>
Christian Brauner Dec. 13, 2024, 6:44 p.m. UTC | #3
On Fri, Dec 13, 2024 at 09:11:41AM -0500, Jeff Layton wrote:
> On Fri, 2024-12-13 at 00:03 +0100, Christian Brauner wrote:
> > Currently we use a read-write lock but for the simple search case we can
> > make this lockless. Creating a new mount namespace is a rather rare
> > event compared with querying mounts in a foreign mount namespace. Once
> > this is picked up by e.g., systemd to list mounts in another mount in
> > it's isolated services or in containers this will be used a lot so this
> > seems worthwhile doing.
> > 
> > Signed-off-by: Christian Brauner <brauner@kernel.org>
> > ---
> >  fs/mount.h     |   5 ++-
> >  fs/namespace.c | 119 +++++++++++++++++++++++++++++++++++----------------------
> >  2 files changed, 77 insertions(+), 47 deletions(-)
> > 
> > diff --git a/fs/mount.h b/fs/mount.h
> > index 185fc56afc13338f8185fe818051444d540cbd5b..36ead0e45e8aa7614c00001102563a711d9dae6e 100644
> > --- a/fs/mount.h
> > +++ b/fs/mount.h
> > @@ -12,7 +12,10 @@ struct mnt_namespace {
> >  	struct user_namespace	*user_ns;
> >  	struct ucounts		*ucounts;
> >  	u64			seq;	/* Sequence number to prevent loops */
> > -	wait_queue_head_t poll;
> > +	union {
> > +		wait_queue_head_t	poll;
> > +		struct rcu_head		mnt_ns_rcu;
> > +	};
> >  	u64 event;
> >  	unsigned int		nr_mounts; /* # of mounts in the namespace */
> >  	unsigned int		pending_mounts;
> > diff --git a/fs/namespace.c b/fs/namespace.c
> > index 10fa18dd66018fadfdc9d18c59a851eed7bd55ad..52adee787eb1b6ee8831705b2b121854c3370fb3 100644
> > --- a/fs/namespace.c
> > +++ b/fs/namespace.c
> > @@ -79,6 +79,8 @@ static DECLARE_RWSEM(namespace_sem);
> >  static HLIST_HEAD(unmounted);	/* protected by namespace_sem */
> >  static LIST_HEAD(ex_mountpoints); /* protected by namespace_sem */
> >  static DEFINE_RWLOCK(mnt_ns_tree_lock);
> > +static seqcount_rwlock_t mnt_ns_tree_seqcount = SEQCNT_RWLOCK_ZERO(mnt_ns_tree_seqcount, &mnt_ns_tree_lock);
> > +
> >  static struct rb_root mnt_ns_tree = RB_ROOT; /* protected by mnt_ns_tree_lock */
> >  
> >  struct mount_kattr {
> > @@ -105,17 +107,6 @@ EXPORT_SYMBOL_GPL(fs_kobj);
> >   */
> >  __cacheline_aligned_in_smp DEFINE_SEQLOCK(mount_lock);
> >  
> > -static int mnt_ns_cmp(u64 seq, const struct mnt_namespace *ns)
> > -{
> > -	u64 seq_b = ns->seq;
> > -
> > -	if (seq < seq_b)
> > -		return -1;
> > -	if (seq > seq_b)
> > -		return 1;
> > -	return 0;
> > -}
> > -
> >  static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
> >  {
> >  	if (!node)
> > @@ -123,19 +114,41 @@ static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
> >  	return rb_entry(node, struct mnt_namespace, mnt_ns_tree_node);
> >  }
> >  
> > -static bool mnt_ns_less(struct rb_node *a, const struct rb_node *b)
> > +static int mnt_ns_cmp(struct rb_node *a, const struct rb_node *b)
> >  {
> >  	struct mnt_namespace *ns_a = node_to_mnt_ns(a);
> >  	struct mnt_namespace *ns_b = node_to_mnt_ns(b);
> >  	u64 seq_a = ns_a->seq;
> > +	u64 seq_b = ns_b->seq;
> > +
> > +	if (seq_a < seq_b)
> > +		return -1;
> > +	if (seq_a > seq_b)
> > +		return 1;
> > +	return 0;
> > +}
> >  
> > -	return mnt_ns_cmp(seq_a, ns_b) < 0;
> > +static inline void mnt_ns_tree_write_lock(void)
> > +{
> > +	write_lock(&mnt_ns_tree_lock);
> > +	write_seqcount_begin(&mnt_ns_tree_seqcount);
> > +}
> > +
> > +static inline void mnt_ns_tree_write_unlock(void)
> > +{
> > +	write_seqcount_end(&mnt_ns_tree_seqcount);
> > +	write_unlock(&mnt_ns_tree_lock);
> >  }
> >  
> >  static void mnt_ns_tree_add(struct mnt_namespace *ns)
> >  {
> > -	guard(write_lock)(&mnt_ns_tree_lock);
> > -	rb_add(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_less);
> > +	struct rb_node *node;
> > +
> > +	mnt_ns_tree_write_lock();
> > +	node = rb_find_add_rcu(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_cmp);
> > +	mnt_ns_tree_write_unlock();
> > +
> > +	WARN_ON_ONCE(node);
> >  }
> >  
> >  static void mnt_ns_release(struct mnt_namespace *ns)
> > @@ -150,41 +163,36 @@ static void mnt_ns_release(struct mnt_namespace *ns)
> >  }
> >  DEFINE_FREE(mnt_ns_release, struct mnt_namespace *, if (_T) mnt_ns_release(_T))
> >  
> > +static void mnt_ns_release_rcu(struct rcu_head *rcu)
> > +{
> > +	struct mnt_namespace *mnt_ns;
> > +
> > +	mnt_ns = container_of(rcu, struct mnt_namespace, mnt_ns_rcu);
> > +	mnt_ns_release(mnt_ns);
> > +}
> > +
> >  static void mnt_ns_tree_remove(struct mnt_namespace *ns)
> >  {
> >  	/* remove from global mount namespace list */
> >  	if (!is_anon_ns(ns)) {
> > -		guard(write_lock)(&mnt_ns_tree_lock);
> > +		mnt_ns_tree_write_lock();
> >  		rb_erase(&ns->mnt_ns_tree_node, &mnt_ns_tree);
> > +		mnt_ns_tree_write_unlock();
> >  	}
> >  
> > -	mnt_ns_release(ns);
> > +	call_rcu(&ns->mnt_ns_rcu, mnt_ns_release_rcu);
> >  }
> >  
> > -/*
> > - * Returns the mount namespace which either has the specified id, or has the
> > - * next smallest id afer the specified one.
> > - */
> > -static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
> > +static int mnt_ns_find(const void *key, const struct rb_node *node)
> >  {
> > -	struct rb_node *node = mnt_ns_tree.rb_node;
> > -	struct mnt_namespace *ret = NULL;
> > -
> > -	lockdep_assert_held(&mnt_ns_tree_lock);
> > -
> > -	while (node) {
> > -		struct mnt_namespace *n = node_to_mnt_ns(node);
> > +	const u64 mnt_ns_id = *(u64 *)key;
> > +	const struct mnt_namespace *ns = node_to_mnt_ns(node);
> >  
> > -		if (mnt_ns_id <= n->seq) {
> > -			ret = node_to_mnt_ns(node);
> > -			if (mnt_ns_id == n->seq)
> > -				break;
> > -			node = node->rb_left;
> > -		} else {
> > -			node = node->rb_right;
> > -		}
> > -	}
> > -	return ret;
> > +	if (mnt_ns_id < ns->seq)
> > +		return -1;
> > +	if (mnt_ns_id > ns->seq)
> > +		return 1;
> > +	return 0;
> >  }
> >  
> >  /*
> > @@ -194,18 +202,37 @@ static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
> >   * namespace the @namespace_sem must first be acquired. If the namespace has
> >   * already shut down before acquiring @namespace_sem, {list,stat}mount() will
> >   * see that the mount rbtree of the namespace is empty.
> > + *
> > + * Note the lookup is lockless protected by a sequence counter. We only
> > + * need to guard against false negatives as false positives aren't
> > + * possible. So if we didn't find a mount namespace and the sequence
> > + * counter has changed we need to retry. If the sequence counter is
> > + * still the same we know the search actually failed.
> >   */
> >  static struct mnt_namespace *lookup_mnt_ns(u64 mnt_ns_id)
> >  {
> > -       struct mnt_namespace *ns;
> > +	struct mnt_namespace *ns;
> > +	struct rb_node *node;
> > +	unsigned int seq;
> > +
> > +	guard(rcu)();
> > +	do {
> > +		seq = read_seqcount_begin(&mnt_ns_tree_seqcount);
> > +		node = rb_find_rcu(&mnt_ns_id, &mnt_ns_tree, mnt_ns_find);
> > +		if (node)
> > +			break;
> > +	} while (read_seqcount_retry(&mnt_ns_tree_seqcount, seq));
> >  
> > -       guard(read_lock)(&mnt_ns_tree_lock);
> > -       ns = mnt_ns_find_id_at(mnt_ns_id);
> > -       if (!ns || ns->seq != mnt_ns_id)
> > -               return NULL;
> > +	if (!node)
> > +		return NULL;
> >  
> > -       refcount_inc(&ns->passive);
> > -       return ns;
> > +	/*
> > +	 * The last reference count is put with RCU delay so we can
> > +	 * unconditonally acquire a reference here.
> > +	 */
> > +	ns = node_to_mnt_ns(node);
> > +	refcount_inc(&ns->passive);
> 
> I'm a little uneasy with the unconditional refcount_inc() here. It
> seems quite possible that this could to a 0->1 transition here. You may
> be right that that technically won't cause a problem with the rcu lock
> held, but at the very least, that will cause a refcount warning to pop.
> 
> Maybe this should be a refcount_inc_not_zero() and then you just return
> NULL if the increment doesn't occur?

So this shouldn't be possible (and Paul is on the thread and can tell me
if I'm wrong) because:

call_rcu(&ns->mnt_ns_rcu, mnt_ns_release_rcu);
-> mnt_ns_release_rcu()
   -> mnt_ns_release()
      -> refcount_dec_and_test()

Which means that decrements are RCU delayed. Any reader must walk the
list holding the RCU lock. If they find the mount namespace still on the
list then mnt_ns_release() will be deferred until they are done.

In order for what you describe to happen a reader must find that mount
namespace still in the list or rbtree and mnt_ns_release() is called
directly.

But afaict that doesn't happen. mnt_ns_release() is only called directly
when the mount namespace has never been on any of the lists.

refcount_inc() will already WARN() if the previous value was 0 and
splat. If we use refcount_inc_not_zero() we're adding additional memory
barriers when we simply don't need them.

But if that's important to you, then I'd rather switch the passive
count to atomic_t and use atomic_dec_and_test() in mnt_ns_release() and
then use similar logic as I used for file_ref_inc():

unsigned long prior = atomic_fetch_inc_relaxed(&ns->passive);
WARN_ONCE(!prior, "increment from zero UAF condition present");

This would give us the opportunity to catch a "shouldn't happen"
condition. But using refcount_inc_not_zero() would, e.g., confuse me a
few months down the line (yes, we could probably add a comment).
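
Spelled out, that would look roughly like the sketch below (illustrative
only; the series as posted keeps refcount_t, and the mnt_ns_get() helper
name exists only for the sketch):

/*
 * Sketch of the atomic_t variant described above, not part of the
 * series: ns->passive becomes an atomic_t so the lookup-side increment
 * can stay relaxed while still catching an increment from zero.
 */
static void mnt_ns_release(struct mnt_namespace *ns)
{
	if (atomic_dec_and_test(&ns->passive)) {
		/* ... existing teardown of the namespace ... */
	}
}

static struct mnt_namespace *mnt_ns_get(struct mnt_namespace *ns)
{
	/* hypothetical helper called from lookup_mnt_ns() under RCU */
	int prior = atomic_fetch_inc_relaxed(&ns->passive);

	WARN_ONCE(!prior, "increment from zero: UAF condition present");
	return ns;
}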

> 
> > +	return ns;
> >  }
> >  
> >  static inline void lock_mount_hash(void)
> > 
> 
> -- 
> Jeff Layton <jlayton@kernel.org>
Jeff Layton Dec. 13, 2024, 7:02 p.m. UTC | #4
On Fri, 2024-12-13 at 19:44 +0100, Christian Brauner wrote:
> On Fri, Dec 13, 2024 at 09:11:41AM -0500, Jeff Layton wrote:
> > On Fri, 2024-12-13 at 00:03 +0100, Christian Brauner wrote:
> > > Currently we use a read-write lock but for the simple search case we can
> > > make this lockless. Creating a new mount namespace is a rather rare
> > > event compared with querying mounts in a foreign mount namespace. Once
> > > this is picked up by e.g., systemd to list mounts in another mount in
> > > it's isolated services or in containers this will be used a lot so this
> > > seems worthwhile doing.
> > > 
> > > Signed-off-by: Christian Brauner <brauner@kernel.org>
> > > ---
> > >  fs/mount.h     |   5 ++-
> > >  fs/namespace.c | 119 +++++++++++++++++++++++++++++++++++----------------------
> > >  2 files changed, 77 insertions(+), 47 deletions(-)
> > > 
> > > diff --git a/fs/mount.h b/fs/mount.h
> > > index 185fc56afc13338f8185fe818051444d540cbd5b..36ead0e45e8aa7614c00001102563a711d9dae6e 100644
> > > --- a/fs/mount.h
> > > +++ b/fs/mount.h
> > > @@ -12,7 +12,10 @@ struct mnt_namespace {
> > >  	struct user_namespace	*user_ns;
> > >  	struct ucounts		*ucounts;
> > >  	u64			seq;	/* Sequence number to prevent loops */
> > > -	wait_queue_head_t poll;
> > > +	union {
> > > +		wait_queue_head_t	poll;
> > > +		struct rcu_head		mnt_ns_rcu;
> > > +	};
> > >  	u64 event;
> > >  	unsigned int		nr_mounts; /* # of mounts in the namespace */
> > >  	unsigned int		pending_mounts;
> > > diff --git a/fs/namespace.c b/fs/namespace.c
> > > index 10fa18dd66018fadfdc9d18c59a851eed7bd55ad..52adee787eb1b6ee8831705b2b121854c3370fb3 100644
> > > --- a/fs/namespace.c
> > > +++ b/fs/namespace.c
> > > @@ -79,6 +79,8 @@ static DECLARE_RWSEM(namespace_sem);
> > >  static HLIST_HEAD(unmounted);	/* protected by namespace_sem */
> > >  static LIST_HEAD(ex_mountpoints); /* protected by namespace_sem */
> > >  static DEFINE_RWLOCK(mnt_ns_tree_lock);
> > > +static seqcount_rwlock_t mnt_ns_tree_seqcount = SEQCNT_RWLOCK_ZERO(mnt_ns_tree_seqcount, &mnt_ns_tree_lock);
> > > +
> > >  static struct rb_root mnt_ns_tree = RB_ROOT; /* protected by mnt_ns_tree_lock */
> > >  
> > >  struct mount_kattr {
> > > @@ -105,17 +107,6 @@ EXPORT_SYMBOL_GPL(fs_kobj);
> > >   */
> > >  __cacheline_aligned_in_smp DEFINE_SEQLOCK(mount_lock);
> > >  
> > > -static int mnt_ns_cmp(u64 seq, const struct mnt_namespace *ns)
> > > -{
> > > -	u64 seq_b = ns->seq;
> > > -
> > > -	if (seq < seq_b)
> > > -		return -1;
> > > -	if (seq > seq_b)
> > > -		return 1;
> > > -	return 0;
> > > -}
> > > -
> > >  static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
> > >  {
> > >  	if (!node)
> > > @@ -123,19 +114,41 @@ static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
> > >  	return rb_entry(node, struct mnt_namespace, mnt_ns_tree_node);
> > >  }
> > >  
> > > -static bool mnt_ns_less(struct rb_node *a, const struct rb_node *b)
> > > +static int mnt_ns_cmp(struct rb_node *a, const struct rb_node *b)
> > >  {
> > >  	struct mnt_namespace *ns_a = node_to_mnt_ns(a);
> > >  	struct mnt_namespace *ns_b = node_to_mnt_ns(b);
> > >  	u64 seq_a = ns_a->seq;
> > > +	u64 seq_b = ns_b->seq;
> > > +
> > > +	if (seq_a < seq_b)
> > > +		return -1;
> > > +	if (seq_a > seq_b)
> > > +		return 1;
> > > +	return 0;
> > > +}
> > >  
> > > -	return mnt_ns_cmp(seq_a, ns_b) < 0;
> > > +static inline void mnt_ns_tree_write_lock(void)
> > > +{
> > > +	write_lock(&mnt_ns_tree_lock);
> > > +	write_seqcount_begin(&mnt_ns_tree_seqcount);
> > > +}
> > > +
> > > +static inline void mnt_ns_tree_write_unlock(void)
> > > +{
> > > +	write_seqcount_end(&mnt_ns_tree_seqcount);
> > > +	write_unlock(&mnt_ns_tree_lock);
> > >  }
> > >  
> > >  static void mnt_ns_tree_add(struct mnt_namespace *ns)
> > >  {
> > > -	guard(write_lock)(&mnt_ns_tree_lock);
> > > -	rb_add(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_less);
> > > +	struct rb_node *node;
> > > +
> > > +	mnt_ns_tree_write_lock();
> > > +	node = rb_find_add_rcu(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_cmp);
> > > +	mnt_ns_tree_write_unlock();
> > > +
> > > +	WARN_ON_ONCE(node);
> > >  }
> > >  
> > >  static void mnt_ns_release(struct mnt_namespace *ns)
> > > @@ -150,41 +163,36 @@ static void mnt_ns_release(struct mnt_namespace *ns)
> > >  }
> > >  DEFINE_FREE(mnt_ns_release, struct mnt_namespace *, if (_T) mnt_ns_release(_T))
> > >  
> > > +static void mnt_ns_release_rcu(struct rcu_head *rcu)
> > > +{
> > > +	struct mnt_namespace *mnt_ns;
> > > +
> > > +	mnt_ns = container_of(rcu, struct mnt_namespace, mnt_ns_rcu);
> > > +	mnt_ns_release(mnt_ns);
> > > +}
> > > +
> > >  static void mnt_ns_tree_remove(struct mnt_namespace *ns)
> > >  {
> > >  	/* remove from global mount namespace list */
> > >  	if (!is_anon_ns(ns)) {
> > > -		guard(write_lock)(&mnt_ns_tree_lock);
> > > +		mnt_ns_tree_write_lock();
> > >  		rb_erase(&ns->mnt_ns_tree_node, &mnt_ns_tree);
> > > +		mnt_ns_tree_write_unlock();
> > >  	}
> > >  
> > > -	mnt_ns_release(ns);
> > > +	call_rcu(&ns->mnt_ns_rcu, mnt_ns_release_rcu);
> > >  }
> > >  
> > > -/*
> > > - * Returns the mount namespace which either has the specified id, or has the
> > > - * next smallest id afer the specified one.
> > > - */
> > > -static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
> > > +static int mnt_ns_find(const void *key, const struct rb_node *node)
> > >  {
> > > -	struct rb_node *node = mnt_ns_tree.rb_node;
> > > -	struct mnt_namespace *ret = NULL;
> > > -
> > > -	lockdep_assert_held(&mnt_ns_tree_lock);
> > > -
> > > -	while (node) {
> > > -		struct mnt_namespace *n = node_to_mnt_ns(node);
> > > +	const u64 mnt_ns_id = *(u64 *)key;
> > > +	const struct mnt_namespace *ns = node_to_mnt_ns(node);
> > >  
> > > -		if (mnt_ns_id <= n->seq) {
> > > -			ret = node_to_mnt_ns(node);
> > > -			if (mnt_ns_id == n->seq)
> > > -				break;
> > > -			node = node->rb_left;
> > > -		} else {
> > > -			node = node->rb_right;
> > > -		}
> > > -	}
> > > -	return ret;
> > > +	if (mnt_ns_id < ns->seq)
> > > +		return -1;
> > > +	if (mnt_ns_id > ns->seq)
> > > +		return 1;
> > > +	return 0;
> > >  }
> > >  
> > >  /*
> > > @@ -194,18 +202,37 @@ static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
> > >   * namespace the @namespace_sem must first be acquired. If the namespace has
> > >   * already shut down before acquiring @namespace_sem, {list,stat}mount() will
> > >   * see that the mount rbtree of the namespace is empty.
> > > + *
> > > + * Note the lookup is lockless protected by a sequence counter. We only
> > > + * need to guard against false negatives as false positives aren't
> > > + * possible. So if we didn't find a mount namespace and the sequence
> > > + * counter has changed we need to retry. If the sequence counter is
> > > + * still the same we know the search actually failed.
> > >   */
> > >  static struct mnt_namespace *lookup_mnt_ns(u64 mnt_ns_id)
> > >  {
> > > -       struct mnt_namespace *ns;
> > > +	struct mnt_namespace *ns;
> > > +	struct rb_node *node;
> > > +	unsigned int seq;
> > > +
> > > +	guard(rcu)();
> > > +	do {
> > > +		seq = read_seqcount_begin(&mnt_ns_tree_seqcount);
> > > +		node = rb_find_rcu(&mnt_ns_id, &mnt_ns_tree, mnt_ns_find);
> > > +		if (node)
> > > +			break;
> > > +	} while (read_seqcount_retry(&mnt_ns_tree_seqcount, seq));
> > >  
> > > -       guard(read_lock)(&mnt_ns_tree_lock);
> > > -       ns = mnt_ns_find_id_at(mnt_ns_id);
> > > -       if (!ns || ns->seq != mnt_ns_id)
> > > -               return NULL;
> > > +	if (!node)
> > > +		return NULL;
> > >  
> > > -       refcount_inc(&ns->passive);
> > > -       return ns;
> > > +	/*
> > > +	 * The last reference count is put with RCU delay so we can
> > > +	 * unconditonally acquire a reference here.
> > > +	 */
> > > +	ns = node_to_mnt_ns(node);
> > > +	refcount_inc(&ns->passive);
> > 
> > I'm a little uneasy with the unconditional refcount_inc() here. It
> > seems quite possible that this could to a 0->1 transition here. You may
> > be right that that technically won't cause a problem with the rcu lock
> > held, but at the very least, that will cause a refcount warning to pop.
> > 
> > Maybe this should be a refcount_inc_not_zero() and then you just return
> > NULL if the increment doesn't occur?
> 
> So this shouldn't be possible (and Paul is on the thread and can tell me
> if I'm wrong) because:
> 
> call_rcu(&ns->mnt_ns_rcu, mnt_ns_release_rcu);
> -> mnt_ns_release_rcu()
>    -> mnt_ns_release()
>       -> refcount_dec_and_test()
> 
> Which means that decrements are RCU delayed. Any reader must walk the
> list holding the RCU lock. If they find the mount namespace still on the
> list then mnt_ns_release() will be deferred until they are done.
> 
>
> In order for what you describe to happen a reader must find that mount
> namespace still in the list or rbtree and mnt_ns_release() is called
> directly.
> 
> But afaict that doesn't happen. mnt_ns_release() is only called directly
> when the mount namespace has never been on any of the lists.
> 

> refcount_inc() will already WARN() if the previous value was 0 and
> splat. If we use refcount_inc_not_zero() we're adding additional memory
> barriers when we simply don't need them.

Ahh ok, I missed the detail that the refcount decrement was being done
in the RCU callback. With that, I guess that can't ever happen. Sorry
for the noise!

> But if that's important to you though than I'd rather switch the passive
> count to atomic_t and use atomic_dec_and_test() in mnt_ns_release() and
> then use similar logic as I used for file_ref_inc(): 
> 
> unsigned long prior = atomic_fetch_inc_relaxed(&ns->passive);
> WARN_ONCE(!prior, "increment from zero UAF condition present");
> 
> This would give us the opportunity to catch a "shouldn't happen
> condition". But using refcount_inc_not_zero() would e.g., confuse me a
> few months down the line (Yes, we could probably add a comment.).
> 

I don't feel strongly about it, but a comment explaining why you can
unconditionally bump it there would be helpful.
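
Something along these lines, say (a sketch of such a comment, not taken from
any later revision of the patch):

	/*
	 * The final passive reference is only dropped from the RCU
	 * callback (mnt_ns_release_rcu()), so a namespace found in the
	 * rbtree under rcu_read_lock() cannot have dropped to zero yet;
	 * the unconditional refcount_inc() can never be a 0 -> 1
	 * transition here.
	 */
	ns = node_to_mnt_ns(node);
	refcount_inc(&ns->passive);
	return ns;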

> > 
> > > +	return ns;
> > >  }
> > >  
> > >  static inline void lock_mount_hash(void)
> > >

Patch

diff --git a/fs/mount.h b/fs/mount.h
index 185fc56afc13338f8185fe818051444d540cbd5b..36ead0e45e8aa7614c00001102563a711d9dae6e 100644
--- a/fs/mount.h
+++ b/fs/mount.h
@@ -12,7 +12,10 @@  struct mnt_namespace {
 	struct user_namespace	*user_ns;
 	struct ucounts		*ucounts;
 	u64			seq;	/* Sequence number to prevent loops */
-	wait_queue_head_t poll;
+	union {
+		wait_queue_head_t	poll;
+		struct rcu_head		mnt_ns_rcu;
+	};
 	u64 event;
 	unsigned int		nr_mounts; /* # of mounts in the namespace */
 	unsigned int		pending_mounts;
diff --git a/fs/namespace.c b/fs/namespace.c
index 10fa18dd66018fadfdc9d18c59a851eed7bd55ad..52adee787eb1b6ee8831705b2b121854c3370fb3 100644
--- a/fs/namespace.c
+++ b/fs/namespace.c
@@ -79,6 +79,8 @@  static DECLARE_RWSEM(namespace_sem);
 static HLIST_HEAD(unmounted);	/* protected by namespace_sem */
 static LIST_HEAD(ex_mountpoints); /* protected by namespace_sem */
 static DEFINE_RWLOCK(mnt_ns_tree_lock);
+static seqcount_rwlock_t mnt_ns_tree_seqcount = SEQCNT_RWLOCK_ZERO(mnt_ns_tree_seqcount, &mnt_ns_tree_lock);
+
 static struct rb_root mnt_ns_tree = RB_ROOT; /* protected by mnt_ns_tree_lock */
 
 struct mount_kattr {
@@ -105,17 +107,6 @@  EXPORT_SYMBOL_GPL(fs_kobj);
  */
 __cacheline_aligned_in_smp DEFINE_SEQLOCK(mount_lock);
 
-static int mnt_ns_cmp(u64 seq, const struct mnt_namespace *ns)
-{
-	u64 seq_b = ns->seq;
-
-	if (seq < seq_b)
-		return -1;
-	if (seq > seq_b)
-		return 1;
-	return 0;
-}
-
 static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
 {
 	if (!node)
@@ -123,19 +114,41 @@  static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
 	return rb_entry(node, struct mnt_namespace, mnt_ns_tree_node);
 }
 
-static bool mnt_ns_less(struct rb_node *a, const struct rb_node *b)
+static int mnt_ns_cmp(struct rb_node *a, const struct rb_node *b)
 {
 	struct mnt_namespace *ns_a = node_to_mnt_ns(a);
 	struct mnt_namespace *ns_b = node_to_mnt_ns(b);
 	u64 seq_a = ns_a->seq;
+	u64 seq_b = ns_b->seq;
+
+	if (seq_a < seq_b)
+		return -1;
+	if (seq_a > seq_b)
+		return 1;
+	return 0;
+}
 
-	return mnt_ns_cmp(seq_a, ns_b) < 0;
+static inline void mnt_ns_tree_write_lock(void)
+{
+	write_lock(&mnt_ns_tree_lock);
+	write_seqcount_begin(&mnt_ns_tree_seqcount);
+}
+
+static inline void mnt_ns_tree_write_unlock(void)
+{
+	write_seqcount_end(&mnt_ns_tree_seqcount);
+	write_unlock(&mnt_ns_tree_lock);
 }
 
 static void mnt_ns_tree_add(struct mnt_namespace *ns)
 {
-	guard(write_lock)(&mnt_ns_tree_lock);
-	rb_add(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_less);
+	struct rb_node *node;
+
+	mnt_ns_tree_write_lock();
+	node = rb_find_add_rcu(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_cmp);
+	mnt_ns_tree_write_unlock();
+
+	WARN_ON_ONCE(node);
 }
 
 static void mnt_ns_release(struct mnt_namespace *ns)
@@ -150,41 +163,36 @@  static void mnt_ns_release(struct mnt_namespace *ns)
 }
 DEFINE_FREE(mnt_ns_release, struct mnt_namespace *, if (_T) mnt_ns_release(_T))
 
+static void mnt_ns_release_rcu(struct rcu_head *rcu)
+{
+	struct mnt_namespace *mnt_ns;
+
+	mnt_ns = container_of(rcu, struct mnt_namespace, mnt_ns_rcu);
+	mnt_ns_release(mnt_ns);
+}
+
 static void mnt_ns_tree_remove(struct mnt_namespace *ns)
 {
 	/* remove from global mount namespace list */
 	if (!is_anon_ns(ns)) {
-		guard(write_lock)(&mnt_ns_tree_lock);
+		mnt_ns_tree_write_lock();
 		rb_erase(&ns->mnt_ns_tree_node, &mnt_ns_tree);
+		mnt_ns_tree_write_unlock();
 	}
 
-	mnt_ns_release(ns);
+	call_rcu(&ns->mnt_ns_rcu, mnt_ns_release_rcu);
 }
 
-/*
- * Returns the mount namespace which either has the specified id, or has the
- * next smallest id afer the specified one.
- */
-static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
+static int mnt_ns_find(const void *key, const struct rb_node *node)
 {
-	struct rb_node *node = mnt_ns_tree.rb_node;
-	struct mnt_namespace *ret = NULL;
-
-	lockdep_assert_held(&mnt_ns_tree_lock);
-
-	while (node) {
-		struct mnt_namespace *n = node_to_mnt_ns(node);
+	const u64 mnt_ns_id = *(u64 *)key;
+	const struct mnt_namespace *ns = node_to_mnt_ns(node);
 
-		if (mnt_ns_id <= n->seq) {
-			ret = node_to_mnt_ns(node);
-			if (mnt_ns_id == n->seq)
-				break;
-			node = node->rb_left;
-		} else {
-			node = node->rb_right;
-		}
-	}
-	return ret;
+	if (mnt_ns_id < ns->seq)
+		return -1;
+	if (mnt_ns_id > ns->seq)
+		return 1;
+	return 0;
 }
 
 /*
@@ -194,18 +202,37 @@  static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
  * namespace the @namespace_sem must first be acquired. If the namespace has
  * already shut down before acquiring @namespace_sem, {list,stat}mount() will
  * see that the mount rbtree of the namespace is empty.
+ *
+ * Note the lookup is lockless protected by a sequence counter. We only
+ * need to guard against false negatives as false positives aren't
+ * possible. So if we didn't find a mount namespace and the sequence
+ * counter has changed we need to retry. If the sequence counter is
+ * still the same we know the search actually failed.
  */
 static struct mnt_namespace *lookup_mnt_ns(u64 mnt_ns_id)
 {
-       struct mnt_namespace *ns;
+	struct mnt_namespace *ns;
+	struct rb_node *node;
+	unsigned int seq;
+
+	guard(rcu)();
+	do {
+		seq = read_seqcount_begin(&mnt_ns_tree_seqcount);
+		node = rb_find_rcu(&mnt_ns_id, &mnt_ns_tree, mnt_ns_find);
+		if (node)
+			break;
+	} while (read_seqcount_retry(&mnt_ns_tree_seqcount, seq));
 
-       guard(read_lock)(&mnt_ns_tree_lock);
-       ns = mnt_ns_find_id_at(mnt_ns_id);
-       if (!ns || ns->seq != mnt_ns_id)
-               return NULL;
+	if (!node)
+		return NULL;
 
-       refcount_inc(&ns->passive);
-       return ns;
+	/*
+	 * The last reference count is put with RCU delay so we can
+	 * unconditionally acquire a reference here.
+	 */
+	ns = node_to_mnt_ns(node);
+	refcount_inc(&ns->passive);
+	return ns;
 }
 
 static inline void lock_mount_hash(void)