diff mbox series

[RFC,11/30] mm: introduce slabobj_ext to support slab object extensions

Message ID 20220830214919.53220-12-surenb@google.com (mailing list archive)
State New
Headers show
Series Code tagging framework and applications | expand

Commit Message

Suren Baghdasaryan Aug. 30, 2022, 9:49 p.m. UTC
Currently slab pages can store only vectors of obj_cgroup pointers in
page->memcg_data. Introduce the slabobj_ext structure to allow more data
to be stored for each slab object. Wrap obj_cgroup in slabobj_ext to
preserve the current functionality while allowing slabobj_ext to be
extended in the future.

Note: ideally the config dependency should be turned the other way around:
MEMCG should depend on SLAB_OBJ_EXT and {page|slab|folio}.memcg_data would
be renamed to something like {page|slab|folio}.objext_data. However, doing
this in an RFC would introduce considerable churn unrelated to the overall
idea, so it is deferred until v1.

Signed-off-by: Suren Baghdasaryan <surenb@google.com>
---
 include/linux/memcontrol.h |  18 ++++--
 init/Kconfig               |   5 ++
 mm/kfence/core.c           |   2 +-
 mm/memcontrol.c            |  60 ++++++++++---------
 mm/page_owner.c            |   2 +-
 mm/slab.h                  | 119 +++++++++++++++++++++++++------------
 6 files changed, 131 insertions(+), 75 deletions(-)

Comments

Roman Gushchin Sept. 1, 2022, 11:35 p.m. UTC | #1
On Tue, Aug 30, 2022 at 02:49:00PM -0700, Suren Baghdasaryan wrote:
> Currently slab pages can store only vectors of obj_cgroup pointers in
> page->memcg_data. Introduce slabobj_ext structure to allow more data
> to be stored for each slab object. Wraps obj_cgroup into slabobj_ext
> to support current functionality while allowing to extend slabobj_ext
> in the future.
> 
> Note: ideally the config dependency should be turned the other way around:
> MEMCG should depend on SLAB_OBJ_EXT and {page|slab|folio}.memcg_data would
> be renamed to something like {page|slab|folio}.objext_data. However doing
> this in RFC would introduce considerable churn unrelated to the overall
> idea, so avoiding this until v1.

Hi Suren!

I'd say CONFIG_MEMCG_KMEM and CONFIG_YOUR_NEW_STUFF should both depend on
SLAB_OBJ_EXT.
CONFIG_MEMCG_KMEM depends on CONFIG_MEMCG anyway.

> 
> Signed-off-by: Suren Baghdasaryan <surenb@google.com>
> ---
>  include/linux/memcontrol.h |  18 ++++--
>  init/Kconfig               |   5 ++
>  mm/kfence/core.c           |   2 +-
>  mm/memcontrol.c            |  60 ++++++++++---------
>  mm/page_owner.c            |   2 +-
>  mm/slab.h                  | 119 +++++++++++++++++++++++++------------
>  6 files changed, 131 insertions(+), 75 deletions(-)
> 
> diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
> index 6257867fbf95..315399f77173 100644
> --- a/include/linux/memcontrol.h
> +++ b/include/linux/memcontrol.h
> @@ -227,6 +227,14 @@ struct obj_cgroup {
>  	};
>  };
>  
> +/*
> + * Extended information for slab objects stored as an array in page->memcg_data
> + * if MEMCG_DATA_OBJEXTS is set.
> + */
> +struct slabobj_ext {
> +	struct obj_cgroup *objcg;
> +} __aligned(8);

Why do we need this alignment requirement?

> +
>  /*
>   * The memory controller data structure. The memory controller controls both
>   * page cache and RSS per cgroup. We would eventually like to provide
> @@ -363,7 +371,7 @@ extern struct mem_cgroup *root_mem_cgroup;
>  
>  enum page_memcg_data_flags {
>  	/* page->memcg_data is a pointer to an objcgs vector */
> -	MEMCG_DATA_OBJCGS = (1UL << 0),
> +	MEMCG_DATA_OBJEXTS = (1UL << 0),
>  	/* page has been accounted as a non-slab kernel page */
>  	MEMCG_DATA_KMEM = (1UL << 1),
>  	/* the next bit after the last actual flag */
> @@ -401,7 +409,7 @@ static inline struct mem_cgroup *__folio_memcg(struct folio *folio)
>  	unsigned long memcg_data = folio->memcg_data;
>  
>  	VM_BUG_ON_FOLIO(folio_test_slab(folio), folio);
> -	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJCGS, folio);
> +	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJEXTS, folio);
>  	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_KMEM, folio);
>  
>  	return (struct mem_cgroup *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> @@ -422,7 +430,7 @@ static inline struct obj_cgroup *__folio_objcg(struct folio *folio)
>  	unsigned long memcg_data = folio->memcg_data;
>  
>  	VM_BUG_ON_FOLIO(folio_test_slab(folio), folio);
> -	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJCGS, folio);
> +	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJEXTS, folio);
>  	VM_BUG_ON_FOLIO(!(memcg_data & MEMCG_DATA_KMEM), folio);
>  
>  	return (struct obj_cgroup *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> @@ -517,7 +525,7 @@ static inline struct mem_cgroup *page_memcg_check(struct page *page)
>  	 */
>  	unsigned long memcg_data = READ_ONCE(page->memcg_data);
>  
> -	if (memcg_data & MEMCG_DATA_OBJCGS)
> +	if (memcg_data & MEMCG_DATA_OBJEXTS)
>  		return NULL;
>  
>  	if (memcg_data & MEMCG_DATA_KMEM) {
> @@ -556,7 +564,7 @@ static inline struct mem_cgroup *get_mem_cgroup_from_objcg(struct obj_cgroup *ob
>  static inline bool folio_memcg_kmem(struct folio *folio)
>  {
>  	VM_BUG_ON_PGFLAGS(PageTail(&folio->page), &folio->page);
> -	VM_BUG_ON_FOLIO(folio->memcg_data & MEMCG_DATA_OBJCGS, folio);
> +	VM_BUG_ON_FOLIO(folio->memcg_data & MEMCG_DATA_OBJEXTS, folio);
>  	return folio->memcg_data & MEMCG_DATA_KMEM;
>  }
>  
> diff --git a/init/Kconfig b/init/Kconfig
> index 532362fcfe31..82396d7a2717 100644
> --- a/init/Kconfig
> +++ b/init/Kconfig
> @@ -958,6 +958,10 @@ config MEMCG
>  	help
>  	  Provides control over the memory footprint of tasks in a cgroup.
>  
> +config SLAB_OBJ_EXT
> +	bool
> +	depends on MEMCG
> +
>  config MEMCG_SWAP
>  	bool
>  	depends on MEMCG && SWAP
> @@ -966,6 +970,7 @@ config MEMCG_SWAP
>  config MEMCG_KMEM
>  	bool
>  	depends on MEMCG && !SLOB
> +	select SLAB_OBJ_EXT
>  	default y
>  
>  config BLK_CGROUP
> diff --git a/mm/kfence/core.c b/mm/kfence/core.c
> index c252081b11df..c0958e4a32e2 100644
> --- a/mm/kfence/core.c
> +++ b/mm/kfence/core.c
> @@ -569,7 +569,7 @@ static unsigned long kfence_init_pool(void)
>  		__folio_set_slab(slab_folio(slab));
>  #ifdef CONFIG_MEMCG
>  		slab->memcg_data = (unsigned long)&kfence_metadata[i / 2 - 1].objcg |
> -				   MEMCG_DATA_OBJCGS;
> +				   MEMCG_DATA_OBJEXTS;
>  #endif
>  	}
>  
> diff --git a/mm/memcontrol.c b/mm/memcontrol.c
> index b69979c9ced5..3f407ef2f3f1 100644
> --- a/mm/memcontrol.c
> +++ b/mm/memcontrol.c
> @@ -2793,7 +2793,7 @@ static void commit_charge(struct folio *folio, struct mem_cgroup *memcg)
>  	folio->memcg_data = (unsigned long)memcg;
>  }
>  
> -#ifdef CONFIG_MEMCG_KMEM
> +#ifdef CONFIG_SLAB_OBJ_EXT
>  /*
>   * The allocated objcg pointers array is not accounted directly.
>   * Moreover, it should not come from DMA buffer and is not readily
> @@ -2801,38 +2801,20 @@ static void commit_charge(struct folio *folio, struct mem_cgroup *memcg)
>   */
>  #define OBJCGS_CLEAR_MASK	(__GFP_DMA | __GFP_RECLAIMABLE | __GFP_ACCOUNT)
>  
> -/*
> - * mod_objcg_mlstate() may be called with irq enabled, so
> - * mod_memcg_lruvec_state() should be used.
> - */
> -static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
> -				     struct pglist_data *pgdat,
> -				     enum node_stat_item idx, int nr)
> -{
> -	struct mem_cgroup *memcg;
> -	struct lruvec *lruvec;
> -
> -	rcu_read_lock();
> -	memcg = obj_cgroup_memcg(objcg);
> -	lruvec = mem_cgroup_lruvec(memcg, pgdat);
> -	mod_memcg_lruvec_state(lruvec, idx, nr);
> -	rcu_read_unlock();
> -}
> -
> -int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
> -				 gfp_t gfp, bool new_slab)
> +int alloc_slab_obj_exts(struct slab *slab, struct kmem_cache *s,
> +			gfp_t gfp, bool new_slab)
>  {
>  	unsigned int objects = objs_per_slab(s, slab);
>  	unsigned long memcg_data;
>  	void *vec;
>  
>  	gfp &= ~OBJCGS_CLEAR_MASK;
> -	vec = kcalloc_node(objects, sizeof(struct obj_cgroup *), gfp,
> +	vec = kcalloc_node(objects, sizeof(struct slabobj_ext), gfp,
>  			   slab_nid(slab));
>  	if (!vec)
>  		return -ENOMEM;
>  
> -	memcg_data = (unsigned long) vec | MEMCG_DATA_OBJCGS;
> +	memcg_data = (unsigned long) vec | MEMCG_DATA_OBJEXTS;
>  	if (new_slab) {
>  		/*
>  		 * If the slab is brand new and nobody can yet access its
> @@ -2843,7 +2825,7 @@ int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
>  	} else if (cmpxchg(&slab->memcg_data, 0, memcg_data)) {
>  		/*
>  		 * If the slab is already in use, somebody can allocate and
> -		 * assign obj_cgroups in parallel. In this case the existing
> +		 * assign slabobj_exts in parallel. In this case the existing
>  		 * objcg vector should be reused.
>  		 */
>  		kfree(vec);
> @@ -2853,6 +2835,26 @@ int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
>  	kmemleak_not_leak(vec);
>  	return 0;
>  }
> +#endif /* CONFIG_SLAB_OBJ_EXT */
> +
> +#ifdef CONFIG_MEMCG_KMEM
> +/*
> + * mod_objcg_mlstate() may be called with irq enabled, so
> + * mod_memcg_lruvec_state() should be used.
> + */
> +static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
> +				     struct pglist_data *pgdat,
> +				     enum node_stat_item idx, int nr)
> +{
> +	struct mem_cgroup *memcg;
> +	struct lruvec *lruvec;
> +
> +	rcu_read_lock();
> +	memcg = obj_cgroup_memcg(objcg);
> +	lruvec = mem_cgroup_lruvec(memcg, pgdat);
> +	mod_memcg_lruvec_state(lruvec, idx, nr);
> +	rcu_read_unlock();
> +}
>  
>  static __always_inline
>  struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
> @@ -2863,18 +2865,18 @@ struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
>  	 * slab->memcg_data.
>  	 */
>  	if (folio_test_slab(folio)) {
> -		struct obj_cgroup **objcgs;
> +		struct slabobj_ext *obj_exts;
>  		struct slab *slab;
>  		unsigned int off;
>  
>  		slab = folio_slab(folio);
> -		objcgs = slab_objcgs(slab);
> -		if (!objcgs)
> +		obj_exts = slab_obj_exts(slab);
> +		if (!obj_exts)
>  			return NULL;
>  
>  		off = obj_to_index(slab->slab_cache, slab, p);
> -		if (objcgs[off])
> -			return obj_cgroup_memcg(objcgs[off]);
> +		if (obj_exts[off].objcg)
> +			return obj_cgroup_memcg(obj_exts[off].objcg);
>  
>  		return NULL;
>  	}
> diff --git a/mm/page_owner.c b/mm/page_owner.c
> index e4c6f3f1695b..fd4af1ad34b8 100644
> --- a/mm/page_owner.c
> +++ b/mm/page_owner.c
> @@ -353,7 +353,7 @@ static inline int print_page_owner_memcg(char *kbuf, size_t count, int ret,
>  	if (!memcg_data)
>  		goto out_unlock;
>  
> -	if (memcg_data & MEMCG_DATA_OBJCGS)
> +	if (memcg_data & MEMCG_DATA_OBJEXTS)
>  		ret += scnprintf(kbuf + ret, count - ret,
>  				"Slab cache page\n");
>  
> diff --git a/mm/slab.h b/mm/slab.h
> index 4ec82bec15ec..c767ce3f0fe2 100644
> --- a/mm/slab.h
> +++ b/mm/slab.h
> @@ -422,36 +422,94 @@ static inline bool kmem_cache_debug_flags(struct kmem_cache *s, slab_flags_t fla
>  	return false;
>  }
>  
> +#ifdef CONFIG_SLAB_OBJ_EXT
> +
> +static inline bool is_kmem_only_obj_ext(void)
> +{
>  #ifdef CONFIG_MEMCG_KMEM
> +	return sizeof(struct slabobj_ext) == sizeof(struct obj_cgroup *);
> +#else
> +	return false;
> +#endif
> +}
> +
>  /*
> - * slab_objcgs - get the object cgroups vector associated with a slab
> + * slab_obj_exts - get the pointer to the slab object extension vector
> + * associated with a slab.
>   * @slab: a pointer to the slab struct
>   *
> - * Returns a pointer to the object cgroups vector associated with the slab,
> + * Returns a pointer to the object extension vector associated with the slab,
>   * or NULL if no such vector has been associated yet.
>   */
> -static inline struct obj_cgroup **slab_objcgs(struct slab *slab)
> +static inline struct slabobj_ext *slab_obj_exts(struct slab *slab)
>  {
>  	unsigned long memcg_data = READ_ONCE(slab->memcg_data);
>  
> -	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJCGS),
> +	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJEXTS),
>  							slab_page(slab));
>  	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
>  
> -	return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> +	return (struct slabobj_ext *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
>  }
>  
> -int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
> -				 gfp_t gfp, bool new_slab);
> -void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
> -		     enum node_stat_item idx, int nr);
> +int alloc_slab_obj_exts(struct slab *slab, struct kmem_cache *s,
> +			gfp_t gfp, bool new_slab);
>  
> -static inline void memcg_free_slab_cgroups(struct slab *slab)
> +static inline void free_slab_obj_exts(struct slab *slab)
>  {
> -	kfree(slab_objcgs(slab));
> +	struct slabobj_ext *obj_exts;
> +
> +	if (!memcg_kmem_enabled() && is_kmem_only_obj_ext())
> +		return;

Hm, not sure I understand this. If kmem is disabled and is_kmem_only_obj_ext()
is true, shouldn't slab->memcg_data == NULL (always)?

> +
> +	obj_exts = slab_obj_exts(slab);
> +	kfree(obj_exts);
>  	slab->memcg_data = 0;
>  }
>  
> +static inline void prepare_slab_obj_exts_hook(struct kmem_cache *s, gfp_t flags, void *p)
> +{
> +	struct slab *slab;
> +
> +	/* If kmem is the only extension then the vector will be created conditionally */
> +	if (is_kmem_only_obj_ext())
> +		return;
> +
> +	slab = virt_to_slab(p);
> +	if (!slab_obj_exts(slab))
> +		WARN(alloc_slab_obj_exts(slab, s, flags, false),
> +			"%s, %s: Failed to create slab extension vector!\n",
> +			__func__, s->name);
> +}

This looks a bit cryptic: the action is wrapped into WARN() and the rest is a set
of (semi-)static checks. Can we, please, invert it? E.g. something like:

if (slab_alloc_tracking_enabled()) {
	slab = virt_to_slab(p);
	if (!slab_obj_exts(slab))
		WARN(alloc_slab_obj_exts(slab, s, flags, false),
		"%s, %s: Failed to create slab extension vector!\n",
		__func__, s->name);
}

The rest looks good to me.

Thank you!
Suren Baghdasaryan Sept. 2, 2022, 12:23 a.m. UTC | #2
On Thu, Sep 1, 2022 at 4:36 PM Roman Gushchin <roman.gushchin@linux.dev> wrote:
>
> On Tue, Aug 30, 2022 at 02:49:00PM -0700, Suren Baghdasaryan wrote:
> > Currently slab pages can store only vectors of obj_cgroup pointers in
> > page->memcg_data. Introduce slabobj_ext structure to allow more data
> > to be stored for each slab object. Wraps obj_cgroup into slabobj_ext
> > to support current functionality while allowing to extend slabobj_ext
> > in the future.
> >
> > Note: ideally the config dependency should be turned the other way around:
> > MEMCG should depend on SLAB_OBJ_EXT and {page|slab|folio}.memcg_data would
> > be renamed to something like {page|slab|folio}.objext_data. However doing
> > this in RFC would introduce considerable churn unrelated to the overall
> > idea, so avoiding this until v1.
>
> Hi Suren!

Hi Roman,

>
> I'd say CONFIG_MEMCG_KMEM and CONFIG_YOUR_NEW_STUFF should both depend on
> SLAB_OBJ_EXT.
> CONFIG_MEMCG_KMEM depends on CONFIG_MEMCG anyway.

Yes, I agree. I wanted to mention here that the current dependency is
incorrect and should be reworked. Having both depending on
SLAB_OBJ_EXT seems like the right approach.

>
> >
> > Signed-off-by: Suren Baghdasaryan <surenb@google.com>
> > ---
> >  include/linux/memcontrol.h |  18 ++++--
> >  init/Kconfig               |   5 ++
> >  mm/kfence/core.c           |   2 +-
> >  mm/memcontrol.c            |  60 ++++++++++---------
> >  mm/page_owner.c            |   2 +-
> >  mm/slab.h                  | 119 +++++++++++++++++++++++++------------
> >  6 files changed, 131 insertions(+), 75 deletions(-)
> >
> > diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
> > index 6257867fbf95..315399f77173 100644
> > --- a/include/linux/memcontrol.h
> > +++ b/include/linux/memcontrol.h
> > @@ -227,6 +227,14 @@ struct obj_cgroup {
> >       };
> >  };
> >
> > +/*
> > + * Extended information for slab objects stored as an array in page->memcg_data
> > + * if MEMCG_DATA_OBJEXTS is set.
> > + */
> > +struct slabobj_ext {
> > +     struct obj_cgroup *objcg;
> > +} __aligned(8);
>
> Why do we need this alignment requirement?

To save space by avoiding padding. However, all members today will be
pointers, so it's meaningless and we can safely drop it.

>
> > +
> >  /*
> >   * The memory controller data structure. The memory controller controls both
> >   * page cache and RSS per cgroup. We would eventually like to provide
> > @@ -363,7 +371,7 @@ extern struct mem_cgroup *root_mem_cgroup;
> >
> >  enum page_memcg_data_flags {
> >       /* page->memcg_data is a pointer to an objcgs vector */
> > -     MEMCG_DATA_OBJCGS = (1UL << 0),
> > +     MEMCG_DATA_OBJEXTS = (1UL << 0),
> >       /* page has been accounted as a non-slab kernel page */
> >       MEMCG_DATA_KMEM = (1UL << 1),
> >       /* the next bit after the last actual flag */
> > @@ -401,7 +409,7 @@ static inline struct mem_cgroup *__folio_memcg(struct folio *folio)
> >       unsigned long memcg_data = folio->memcg_data;
> >
> >       VM_BUG_ON_FOLIO(folio_test_slab(folio), folio);
> > -     VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJCGS, folio);
> > +     VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJEXTS, folio);
> >       VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_KMEM, folio);
> >
> >       return (struct mem_cgroup *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> > @@ -422,7 +430,7 @@ static inline struct obj_cgroup *__folio_objcg(struct folio *folio)
> >       unsigned long memcg_data = folio->memcg_data;
> >
> >       VM_BUG_ON_FOLIO(folio_test_slab(folio), folio);
> > -     VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJCGS, folio);
> > +     VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJEXTS, folio);
> >       VM_BUG_ON_FOLIO(!(memcg_data & MEMCG_DATA_KMEM), folio);
> >
> >       return (struct obj_cgroup *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> > @@ -517,7 +525,7 @@ static inline struct mem_cgroup *page_memcg_check(struct page *page)
> >        */
> >       unsigned long memcg_data = READ_ONCE(page->memcg_data);
> >
> > -     if (memcg_data & MEMCG_DATA_OBJCGS)
> > +     if (memcg_data & MEMCG_DATA_OBJEXTS)
> >               return NULL;
> >
> >       if (memcg_data & MEMCG_DATA_KMEM) {
> > @@ -556,7 +564,7 @@ static inline struct mem_cgroup *get_mem_cgroup_from_objcg(struct obj_cgroup *ob
> >  static inline bool folio_memcg_kmem(struct folio *folio)
> >  {
> >       VM_BUG_ON_PGFLAGS(PageTail(&folio->page), &folio->page);
> > -     VM_BUG_ON_FOLIO(folio->memcg_data & MEMCG_DATA_OBJCGS, folio);
> > +     VM_BUG_ON_FOLIO(folio->memcg_data & MEMCG_DATA_OBJEXTS, folio);
> >       return folio->memcg_data & MEMCG_DATA_KMEM;
> >  }
> >
> > diff --git a/init/Kconfig b/init/Kconfig
> > index 532362fcfe31..82396d7a2717 100644
> > --- a/init/Kconfig
> > +++ b/init/Kconfig
> > @@ -958,6 +958,10 @@ config MEMCG
> >       help
> >         Provides control over the memory footprint of tasks in a cgroup.
> >
> > +config SLAB_OBJ_EXT
> > +     bool
> > +     depends on MEMCG
> > +
> >  config MEMCG_SWAP
> >       bool
> >       depends on MEMCG && SWAP
> > @@ -966,6 +970,7 @@ config MEMCG_SWAP
> >  config MEMCG_KMEM
> >       bool
> >       depends on MEMCG && !SLOB
> > +     select SLAB_OBJ_EXT
> >       default y
> >
> >  config BLK_CGROUP
> > diff --git a/mm/kfence/core.c b/mm/kfence/core.c
> > index c252081b11df..c0958e4a32e2 100644
> > --- a/mm/kfence/core.c
> > +++ b/mm/kfence/core.c
> > @@ -569,7 +569,7 @@ static unsigned long kfence_init_pool(void)
> >               __folio_set_slab(slab_folio(slab));
> >  #ifdef CONFIG_MEMCG
> >               slab->memcg_data = (unsigned long)&kfence_metadata[i / 2 - 1].objcg |
> > -                                MEMCG_DATA_OBJCGS;
> > +                                MEMCG_DATA_OBJEXTS;
> >  #endif
> >       }
> >
> > diff --git a/mm/memcontrol.c b/mm/memcontrol.c
> > index b69979c9ced5..3f407ef2f3f1 100644
> > --- a/mm/memcontrol.c
> > +++ b/mm/memcontrol.c
> > @@ -2793,7 +2793,7 @@ static void commit_charge(struct folio *folio, struct mem_cgroup *memcg)
> >       folio->memcg_data = (unsigned long)memcg;
> >  }
> >
> > -#ifdef CONFIG_MEMCG_KMEM
> > +#ifdef CONFIG_SLAB_OBJ_EXT
> >  /*
> >   * The allocated objcg pointers array is not accounted directly.
> >   * Moreover, it should not come from DMA buffer and is not readily
> > @@ -2801,38 +2801,20 @@ static void commit_charge(struct folio *folio, struct mem_cgroup *memcg)
> >   */
> >  #define OBJCGS_CLEAR_MASK    (__GFP_DMA | __GFP_RECLAIMABLE | __GFP_ACCOUNT)
> >
> > -/*
> > - * mod_objcg_mlstate() may be called with irq enabled, so
> > - * mod_memcg_lruvec_state() should be used.
> > - */
> > -static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
> > -                                  struct pglist_data *pgdat,
> > -                                  enum node_stat_item idx, int nr)
> > -{
> > -     struct mem_cgroup *memcg;
> > -     struct lruvec *lruvec;
> > -
> > -     rcu_read_lock();
> > -     memcg = obj_cgroup_memcg(objcg);
> > -     lruvec = mem_cgroup_lruvec(memcg, pgdat);
> > -     mod_memcg_lruvec_state(lruvec, idx, nr);
> > -     rcu_read_unlock();
> > -}
> > -
> > -int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
> > -                              gfp_t gfp, bool new_slab)
> > +int alloc_slab_obj_exts(struct slab *slab, struct kmem_cache *s,
> > +                     gfp_t gfp, bool new_slab)
> >  {
> >       unsigned int objects = objs_per_slab(s, slab);
> >       unsigned long memcg_data;
> >       void *vec;
> >
> >       gfp &= ~OBJCGS_CLEAR_MASK;
> > -     vec = kcalloc_node(objects, sizeof(struct obj_cgroup *), gfp,
> > +     vec = kcalloc_node(objects, sizeof(struct slabobj_ext), gfp,
> >                          slab_nid(slab));
> >       if (!vec)
> >               return -ENOMEM;
> >
> > -     memcg_data = (unsigned long) vec | MEMCG_DATA_OBJCGS;
> > +     memcg_data = (unsigned long) vec | MEMCG_DATA_OBJEXTS;
> >       if (new_slab) {
> >               /*
> >                * If the slab is brand new and nobody can yet access its
> > @@ -2843,7 +2825,7 @@ int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
> >       } else if (cmpxchg(&slab->memcg_data, 0, memcg_data)) {
> >               /*
> >                * If the slab is already in use, somebody can allocate and
> > -              * assign obj_cgroups in parallel. In this case the existing
> > +              * assign slabobj_exts in parallel. In this case the existing
> >                * objcg vector should be reused.
> >                */
> >               kfree(vec);
> > @@ -2853,6 +2835,26 @@ int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
> >       kmemleak_not_leak(vec);
> >       return 0;
> >  }
> > +#endif /* CONFIG_SLAB_OBJ_EXT */
> > +
> > +#ifdef CONFIG_MEMCG_KMEM
> > +/*
> > + * mod_objcg_mlstate() may be called with irq enabled, so
> > + * mod_memcg_lruvec_state() should be used.
> > + */
> > +static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
> > +                                  struct pglist_data *pgdat,
> > +                                  enum node_stat_item idx, int nr)
> > +{
> > +     struct mem_cgroup *memcg;
> > +     struct lruvec *lruvec;
> > +
> > +     rcu_read_lock();
> > +     memcg = obj_cgroup_memcg(objcg);
> > +     lruvec = mem_cgroup_lruvec(memcg, pgdat);
> > +     mod_memcg_lruvec_state(lruvec, idx, nr);
> > +     rcu_read_unlock();
> > +}
> >
> >  static __always_inline
> >  struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
> > @@ -2863,18 +2865,18 @@ struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
> >        * slab->memcg_data.
> >        */
> >       if (folio_test_slab(folio)) {
> > -             struct obj_cgroup **objcgs;
> > +             struct slabobj_ext *obj_exts;
> >               struct slab *slab;
> >               unsigned int off;
> >
> >               slab = folio_slab(folio);
> > -             objcgs = slab_objcgs(slab);
> > -             if (!objcgs)
> > +             obj_exts = slab_obj_exts(slab);
> > +             if (!obj_exts)
> >                       return NULL;
> >
> >               off = obj_to_index(slab->slab_cache, slab, p);
> > -             if (objcgs[off])
> > -                     return obj_cgroup_memcg(objcgs[off]);
> > +             if (obj_exts[off].objcg)
> > +                     return obj_cgroup_memcg(obj_exts[off].objcg);
> >
> >               return NULL;
> >       }
> > diff --git a/mm/page_owner.c b/mm/page_owner.c
> > index e4c6f3f1695b..fd4af1ad34b8 100644
> > --- a/mm/page_owner.c
> > +++ b/mm/page_owner.c
> > @@ -353,7 +353,7 @@ static inline int print_page_owner_memcg(char *kbuf, size_t count, int ret,
> >       if (!memcg_data)
> >               goto out_unlock;
> >
> > -     if (memcg_data & MEMCG_DATA_OBJCGS)
> > +     if (memcg_data & MEMCG_DATA_OBJEXTS)
> >               ret += scnprintf(kbuf + ret, count - ret,
> >                               "Slab cache page\n");
> >
> > diff --git a/mm/slab.h b/mm/slab.h
> > index 4ec82bec15ec..c767ce3f0fe2 100644
> > --- a/mm/slab.h
> > +++ b/mm/slab.h
> > @@ -422,36 +422,94 @@ static inline bool kmem_cache_debug_flags(struct kmem_cache *s, slab_flags_t fla
> >       return false;
> >  }
> >
> > +#ifdef CONFIG_SLAB_OBJ_EXT
> > +
> > +static inline bool is_kmem_only_obj_ext(void)
> > +{
> >  #ifdef CONFIG_MEMCG_KMEM
> > +     return sizeof(struct slabobj_ext) == sizeof(struct obj_cgroup *);
> > +#else
> > +     return false;
> > +#endif
> > +}
> > +
> >  /*
> > - * slab_objcgs - get the object cgroups vector associated with a slab
> > + * slab_obj_exts - get the pointer to the slab object extension vector
> > + * associated with a slab.
> >   * @slab: a pointer to the slab struct
> >   *
> > - * Returns a pointer to the object cgroups vector associated with the slab,
> > + * Returns a pointer to the object extension vector associated with the slab,
> >   * or NULL if no such vector has been associated yet.
> >   */
> > -static inline struct obj_cgroup **slab_objcgs(struct slab *slab)
> > +static inline struct slabobj_ext *slab_obj_exts(struct slab *slab)
> >  {
> >       unsigned long memcg_data = READ_ONCE(slab->memcg_data);
> >
> > -     VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJCGS),
> > +     VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJEXTS),
> >                                                       slab_page(slab));
> >       VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
> >
> > -     return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> > +     return (struct slabobj_ext *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> >  }
> >
> > -int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
> > -                              gfp_t gfp, bool new_slab);
> > -void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
> > -                  enum node_stat_item idx, int nr);
> > +int alloc_slab_obj_exts(struct slab *slab, struct kmem_cache *s,
> > +                     gfp_t gfp, bool new_slab);
> >
> > -static inline void memcg_free_slab_cgroups(struct slab *slab)
> > +static inline void free_slab_obj_exts(struct slab *slab)
> >  {
> > -     kfree(slab_objcgs(slab));
> > +     struct slabobj_ext *obj_exts;
> > +
> > +     if (!memcg_kmem_enabled() && is_kmem_only_obj_ext())
> > +             return;
>
> Hm, not sure I understand this. If kmem is disabled and is_kmem_only_obj_ext()
> is true, shouldn't slab->memcg_data == NULL (always)?

So, the logic was to skip freeing when the only possible objects in
slab->memcg_data are "struct obj_cgroup" and kmem is disabled.
Otherwise, slab->memcg_data may hold other objects, which have to be
freed. Did I make it more complicated than it should have
been?

>
> > +
> > +     obj_exts = slab_obj_exts(slab);
> > +     kfree(obj_exts);
> >       slab->memcg_data = 0;
> >  }
> >
> > +static inline void prepare_slab_obj_exts_hook(struct kmem_cache *s, gfp_t flags, void *p)
> > +{
> > +     struct slab *slab;
> > +
> > +     /* If kmem is the only extension then the vector will be created conditionally */
> > +     if (is_kmem_only_obj_ext())
> > +             return;
> > +
> > +     slab = virt_to_slab(p);
> > +     if (!slab_obj_exts(slab))
> > +             WARN(alloc_slab_obj_exts(slab, s, flags, false),
> > +                     "%s, %s: Failed to create slab extension vector!\n",
> > +                     __func__, s->name);
> > +}
>
> This looks a bit cryptic: the action is wrapped into WARN() and the rest is a set
> of (semi-)static checks. Can we, please, invert it? E.g. something like:
>
> if (slab_alloc_tracking_enabled()) {
>         slab = virt_to_slab(p);
>         if (!slab_obj_exts(slab))
>                 WARN(alloc_slab_obj_exts(slab, s, flags, false),
>                 "%s, %s: Failed to create slab extension vector!\n",
>                 __func__, s->name);
> }

Yeah, this is much more readable. Thanks for the suggestion and for
reviewing the code!

>
> The rest looks good to me.
>
> Thank you!
>
> --
> To unsubscribe from this group and stop receiving emails from it, send an email to kernel-team+unsubscribe@android.com.
>
diff mbox series

Patch

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 6257867fbf95..315399f77173 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -227,6 +227,14 @@  struct obj_cgroup {
 	};
 };
 
+/*
+ * Extended information for slab objects stored as an array in page->memcg_data
+ * if MEMCG_DATA_OBJEXTS is set.
+ */
+struct slabobj_ext {
+	struct obj_cgroup *objcg;
+} __aligned(8);
+
 /*
  * The memory controller data structure. The memory controller controls both
  * page cache and RSS per cgroup. We would eventually like to provide
@@ -363,7 +371,7 @@  extern struct mem_cgroup *root_mem_cgroup;
 
 enum page_memcg_data_flags {
 	/* page->memcg_data is a pointer to an objcgs vector */
-	MEMCG_DATA_OBJCGS = (1UL << 0),
+	MEMCG_DATA_OBJEXTS = (1UL << 0),
 	/* page has been accounted as a non-slab kernel page */
 	MEMCG_DATA_KMEM = (1UL << 1),
 	/* the next bit after the last actual flag */
@@ -401,7 +409,7 @@  static inline struct mem_cgroup *__folio_memcg(struct folio *folio)
 	unsigned long memcg_data = folio->memcg_data;
 
 	VM_BUG_ON_FOLIO(folio_test_slab(folio), folio);
-	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJCGS, folio);
+	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJEXTS, folio);
 	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_KMEM, folio);
 
 	return (struct mem_cgroup *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
@@ -422,7 +430,7 @@  static inline struct obj_cgroup *__folio_objcg(struct folio *folio)
 	unsigned long memcg_data = folio->memcg_data;
 
 	VM_BUG_ON_FOLIO(folio_test_slab(folio), folio);
-	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJCGS, folio);
+	VM_BUG_ON_FOLIO(memcg_data & MEMCG_DATA_OBJEXTS, folio);
 	VM_BUG_ON_FOLIO(!(memcg_data & MEMCG_DATA_KMEM), folio);
 
 	return (struct obj_cgroup *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
@@ -517,7 +525,7 @@  static inline struct mem_cgroup *page_memcg_check(struct page *page)
 	 */
 	unsigned long memcg_data = READ_ONCE(page->memcg_data);
 
-	if (memcg_data & MEMCG_DATA_OBJCGS)
+	if (memcg_data & MEMCG_DATA_OBJEXTS)
 		return NULL;
 
 	if (memcg_data & MEMCG_DATA_KMEM) {
@@ -556,7 +564,7 @@  static inline struct mem_cgroup *get_mem_cgroup_from_objcg(struct obj_cgroup *ob
 static inline bool folio_memcg_kmem(struct folio *folio)
 {
 	VM_BUG_ON_PGFLAGS(PageTail(&folio->page), &folio->page);
-	VM_BUG_ON_FOLIO(folio->memcg_data & MEMCG_DATA_OBJCGS, folio);
+	VM_BUG_ON_FOLIO(folio->memcg_data & MEMCG_DATA_OBJEXTS, folio);
 	return folio->memcg_data & MEMCG_DATA_KMEM;
 }
 
diff --git a/init/Kconfig b/init/Kconfig
index 532362fcfe31..82396d7a2717 100644
--- a/init/Kconfig
+++ b/init/Kconfig
@@ -958,6 +958,10 @@  config MEMCG
 	help
 	  Provides control over the memory footprint of tasks in a cgroup.
 
+config SLAB_OBJ_EXT
+	bool
+	depends on MEMCG
+
 config MEMCG_SWAP
 	bool
 	depends on MEMCG && SWAP
@@ -966,6 +970,7 @@  config MEMCG_SWAP
 config MEMCG_KMEM
 	bool
 	depends on MEMCG && !SLOB
+	select SLAB_OBJ_EXT
 	default y
 
 config BLK_CGROUP
diff --git a/mm/kfence/core.c b/mm/kfence/core.c
index c252081b11df..c0958e4a32e2 100644
--- a/mm/kfence/core.c
+++ b/mm/kfence/core.c
@@ -569,7 +569,7 @@  static unsigned long kfence_init_pool(void)
 		__folio_set_slab(slab_folio(slab));
 #ifdef CONFIG_MEMCG
 		slab->memcg_data = (unsigned long)&kfence_metadata[i / 2 - 1].objcg |
-				   MEMCG_DATA_OBJCGS;
+				   MEMCG_DATA_OBJEXTS;
 #endif
 	}
 
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index b69979c9ced5..3f407ef2f3f1 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2793,7 +2793,7 @@  static void commit_charge(struct folio *folio, struct mem_cgroup *memcg)
 	folio->memcg_data = (unsigned long)memcg;
 }
 
-#ifdef CONFIG_MEMCG_KMEM
+#ifdef CONFIG_SLAB_OBJ_EXT
 /*
  * The allocated objcg pointers array is not accounted directly.
  * Moreover, it should not come from DMA buffer and is not readily
@@ -2801,38 +2801,20 @@  static void commit_charge(struct folio *folio, struct mem_cgroup *memcg)
  */
 #define OBJCGS_CLEAR_MASK	(__GFP_DMA | __GFP_RECLAIMABLE | __GFP_ACCOUNT)
 
-/*
- * mod_objcg_mlstate() may be called with irq enabled, so
- * mod_memcg_lruvec_state() should be used.
- */
-static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
-				     struct pglist_data *pgdat,
-				     enum node_stat_item idx, int nr)
-{
-	struct mem_cgroup *memcg;
-	struct lruvec *lruvec;
-
-	rcu_read_lock();
-	memcg = obj_cgroup_memcg(objcg);
-	lruvec = mem_cgroup_lruvec(memcg, pgdat);
-	mod_memcg_lruvec_state(lruvec, idx, nr);
-	rcu_read_unlock();
-}
-
-int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
-				 gfp_t gfp, bool new_slab)
+int alloc_slab_obj_exts(struct slab *slab, struct kmem_cache *s,
+			gfp_t gfp, bool new_slab)
 {
 	unsigned int objects = objs_per_slab(s, slab);
 	unsigned long memcg_data;
 	void *vec;
 
 	gfp &= ~OBJCGS_CLEAR_MASK;
-	vec = kcalloc_node(objects, sizeof(struct obj_cgroup *), gfp,
+	vec = kcalloc_node(objects, sizeof(struct slabobj_ext), gfp,
 			   slab_nid(slab));
 	if (!vec)
 		return -ENOMEM;
 
-	memcg_data = (unsigned long) vec | MEMCG_DATA_OBJCGS;
+	memcg_data = (unsigned long) vec | MEMCG_DATA_OBJEXTS;
 	if (new_slab) {
 		/*
 		 * If the slab is brand new and nobody can yet access its
@@ -2843,7 +2825,7 @@  int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
 	} else if (cmpxchg(&slab->memcg_data, 0, memcg_data)) {
 		/*
 		 * If the slab is already in use, somebody can allocate and
-		 * assign obj_cgroups in parallel. In this case the existing
+		 * assign slabobj_exts in parallel. In this case the existing
 		 * objcg vector should be reused.
 		 */
 		kfree(vec);
@@ -2853,6 +2835,26 @@  int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
 	kmemleak_not_leak(vec);
 	return 0;
 }
+#endif /* CONFIG_SLAB_OBJ_EXT */
+
+#ifdef CONFIG_MEMCG_KMEM
+/*
+ * Add @nr to stat @idx of objcg's memcg lruvec on @pgdat. May be called with
+ * irqs on, hence mod_memcg_lruvec_state(); memcg is resolved under RCU.
+ */
+static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
+				     struct pglist_data *pgdat,
+				     enum node_stat_item idx, int nr)
+{
+	struct mem_cgroup *memcg;
+	struct lruvec *lruvec;
+
+	rcu_read_lock();
+	memcg = obj_cgroup_memcg(objcg);
+	lruvec = mem_cgroup_lruvec(memcg, pgdat);
+	mod_memcg_lruvec_state(lruvec, idx, nr);
+	rcu_read_unlock();
+}
 
 static __always_inline
 struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
@@ -2863,18 +2865,18 @@  struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
 	 * slab->memcg_data.
 	 */
 	if (folio_test_slab(folio)) {
-		struct obj_cgroup **objcgs;
+		struct slabobj_ext *obj_exts;
 		struct slab *slab;
 		unsigned int off;
 
 		slab = folio_slab(folio);
-		objcgs = slab_objcgs(slab);
-		if (!objcgs)
+		obj_exts = slab_obj_exts(slab);
+		if (!obj_exts)
 			return NULL;
 
 		off = obj_to_index(slab->slab_cache, slab, p);
-		if (objcgs[off])
-			return obj_cgroup_memcg(objcgs[off]);
+		if (obj_exts[off].objcg)
+			return obj_cgroup_memcg(obj_exts[off].objcg);
 
 		return NULL;
 	}
diff --git a/mm/page_owner.c b/mm/page_owner.c
index e4c6f3f1695b..fd4af1ad34b8 100644
--- a/mm/page_owner.c
+++ b/mm/page_owner.c
@@ -353,7 +353,7 @@  static inline int print_page_owner_memcg(char *kbuf, size_t count, int ret,
 	if (!memcg_data)
 		goto out_unlock;
 
-	if (memcg_data & MEMCG_DATA_OBJCGS)
+	if (memcg_data & MEMCG_DATA_OBJEXTS)
 		ret += scnprintf(kbuf + ret, count - ret,
 				"Slab cache page\n");
 
diff --git a/mm/slab.h b/mm/slab.h
index 4ec82bec15ec..c767ce3f0fe2 100644
--- a/mm/slab.h
+++ b/mm/slab.h
@@ -422,36 +422,94 @@  static inline bool kmem_cache_debug_flags(struct kmem_cache *s, slab_flags_t fla
 	return false;
 }
 
+#ifdef CONFIG_SLAB_OBJ_EXT
+/* Is slabobj_ext just the objcg pointer (judged by size)? false without MEMCG_KMEM */
+static inline bool is_kmem_only_obj_ext(void)
+{
 #ifdef CONFIG_MEMCG_KMEM
+	return sizeof(struct slabobj_ext) == sizeof(struct obj_cgroup *);
+#else
+	return false;
+#endif
+}
+
 /*
- * slab_objcgs - get the object cgroups vector associated with a slab
+ * slab_obj_exts - get the pointer to the slab object extension vector
+ * associated with a slab.
  * @slab: a pointer to the slab struct
  *
- * Returns a pointer to the object cgroups vector associated with the slab,
+ * Return: a pointer to the object extension vector associated with the slab,
  * or NULL if no such vector has been associated yet.
  */
-static inline struct obj_cgroup **slab_objcgs(struct slab *slab)
+static inline struct slabobj_ext *slab_obj_exts(struct slab *slab)
 {
 	unsigned long memcg_data = READ_ONCE(slab->memcg_data);
 
-	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJCGS),
+	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJEXTS),
 							slab_page(slab));
 	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
 
-	return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
+	return (struct slabobj_ext *)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
 }
 
-int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
-				 gfp_t gfp, bool new_slab);
-void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
-		     enum node_stat_item idx, int nr);
+int alloc_slab_obj_exts(struct slab *slab, struct kmem_cache *s,
+			gfp_t gfp, bool new_slab);
 
-static inline void memcg_free_slab_cgroups(struct slab *slab)
+static inline void free_slab_obj_exts(struct slab *slab)
 {
-	kfree(slab_objcgs(slab));
+	/*
+	 * If kmem accounting is the only extension user and it is disabled, no
+	 * vector is expected to have been allocated (see account_slab()).
+	 */
+	if (is_kmem_only_obj_ext() && !memcg_kmem_enabled())
+		return;
+	kfree(slab_obj_exts(slab));
 	slab->memcg_data = 0;
 }
 
+static inline void prepare_slab_obj_exts_hook(struct kmem_cache *s, gfp_t flags, void *p)
+{
+	struct slab *slab;
+
+	/* NULL p: nothing to attach to. kmem-only: vector is created conditionally. */
+	if (!p || is_kmem_only_obj_ext())
+		return;
+
+	slab = virt_to_slab(p);
+	if (!slab_obj_exts(slab))
+		WARN(alloc_slab_obj_exts(slab, s, flags, false),
+			"%s, %s: Failed to create slab extension vector!\n",
+			__func__, s->name);
+}
+
+#else /* CONFIG_SLAB_OBJ_EXT */
+/* Stubs: slab_obj_exts() is always NULL; alloc reports success, free/prepare do nothing. */
+static inline struct slabobj_ext *slab_obj_exts(struct slab *slab)
+{
+	return NULL;
+}
+
+static inline int alloc_slab_obj_exts(struct slab *slab,
+				      struct kmem_cache *s, gfp_t gfp,
+				      bool new_slab)
+{
+	return 0;
+}
+
+static inline void free_slab_obj_exts(struct slab *slab)
+{
+}
+
+static inline void prepare_slab_obj_exts_hook(struct kmem_cache *s, gfp_t flags, void *p)
+{
+}
+
+#endif /* CONFIG_SLAB_OBJ_EXT */
+
+#ifdef CONFIG_MEMCG_KMEM
+void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
+		     enum node_stat_item idx, int nr);
+
 static inline size_t obj_full_size(struct kmem_cache *s)
 {
 	/*
@@ -519,16 +577,15 @@  static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
 		if (likely(p[i])) {
 			slab = virt_to_slab(p[i]);
 
-			if (!slab_objcgs(slab) &&
-			    memcg_alloc_slab_cgroups(slab, s, flags,
-							 false)) {
+			if (!slab_obj_exts(slab) &&
+			    alloc_slab_obj_exts(slab, s, flags, false)) {
 				obj_cgroup_uncharge(objcg, obj_full_size(s));
 				continue;
 			}
 
 			off = obj_to_index(s, slab, p[i]);
 			obj_cgroup_get(objcg);
-			slab_objcgs(slab)[off] = objcg;
+			slab_obj_exts(slab)[off].objcg = objcg;
 			mod_objcg_state(objcg, slab_pgdat(slab),
 					cache_vmstat_idx(s), obj_full_size(s));
 		} else {
@@ -541,14 +598,14 @@  static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
 static inline void memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,
 					void **p, int objects)
 {
-	struct obj_cgroup **objcgs;
+	struct slabobj_ext *obj_exts;
 	int i;
 
 	if (!memcg_kmem_enabled())
 		return;
 
-	objcgs = slab_objcgs(slab);
-	if (!objcgs)
+	obj_exts = slab_obj_exts(slab);
+	if (!obj_exts)
 		return;
 
 	for (i = 0; i < objects; i++) {
@@ -556,11 +613,11 @@  static inline void memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,
 		unsigned int off;
 
 		off = obj_to_index(s, slab, p[i]);
-		objcg = objcgs[off];
+		objcg = obj_exts[off].objcg;
 		if (!objcg)
 			continue;
 
-		objcgs[off] = NULL;
+		obj_exts[off].objcg = NULL;
 		obj_cgroup_uncharge(objcg, obj_full_size(s));
 		mod_objcg_state(objcg, slab_pgdat(slab), cache_vmstat_idx(s),
 				-obj_full_size(s));
@@ -569,27 +626,11 @@  static inline void memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,
 }
 
 #else /* CONFIG_MEMCG_KMEM */
-static inline struct obj_cgroup **slab_objcgs(struct slab *slab)
-{
-	return NULL;
-}
-
 static inline struct mem_cgroup *memcg_from_slab_obj(void *ptr)
 {
 	return NULL;
 }
 
-static inline int memcg_alloc_slab_cgroups(struct slab *slab,
-					       struct kmem_cache *s, gfp_t gfp,
-					       bool new_slab)
-{
-	return 0;
-}
-
-static inline void memcg_free_slab_cgroups(struct slab *slab)
-{
-}
-
 static inline bool memcg_slab_pre_alloc_hook(struct kmem_cache *s,
 					     struct list_lru *lru,
 					     struct obj_cgroup **objcgp,
@@ -627,7 +668,7 @@  static __always_inline void account_slab(struct slab *slab, int order,
 					 struct kmem_cache *s, gfp_t gfp)
 {
 	if (memcg_kmem_enabled() && (s->flags & SLAB_ACCOUNT))
-		memcg_alloc_slab_cgroups(slab, s, gfp, true);
+		alloc_slab_obj_exts(slab, s, gfp, true);
 
 	mod_node_page_state(slab_pgdat(slab), cache_vmstat_idx(s),
 			    PAGE_SIZE << order);
@@ -636,8 +677,7 @@  static __always_inline void account_slab(struct slab *slab, int order,
 static __always_inline void unaccount_slab(struct slab *slab, int order,
 					   struct kmem_cache *s)
 {
-	if (memcg_kmem_enabled())
-		memcg_free_slab_cgroups(slab);
+	free_slab_obj_exts(slab);
 
 	mod_node_page_state(slab_pgdat(slab), cache_vmstat_idx(s),
 			    -(PAGE_SIZE << order));
@@ -729,6 +769,7 @@  static inline void slab_post_alloc_hook(struct kmem_cache *s,
 			memset(p[i], 0, s->object_size);
 		kmemleak_alloc_recursive(p[i], s->object_size, 1,
 					 s->flags, flags);
+		prepare_slab_obj_exts_hook(s, flags, p[i]);
 	}
 
 	memcg_slab_post_alloc_hook(s, objcg, flags, size, p);