diff mbox series

[RFC,bpf-next,3/5] bpf: pin, translate, and unpin __kptr_user from syscalls.

Message ID 20240807235755.1435806-4-thinker.li@gmail.com (mailing list archive)
State Superseded
Delegated to: BPF
Headers show
Series Share user memory to BPF program through task storage map.

Checks

Context Check Description
bpf/vmtest-bpf-next-PR success PR summary
bpf/vmtest-bpf-next-VM_Test-0 success Logs for Lint
bpf/vmtest-bpf-next-VM_Test-3 success Logs for Validate matrix.py
bpf/vmtest-bpf-next-VM_Test-2 success Logs for Unittests
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-5 success Logs for aarch64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-19 success Logs for x86_64-gcc / build / build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-13 success Logs for s390x-gcc / test (test_maps, false, 360) / test_maps on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-6 success Logs for aarch64-gcc / test (test_maps, false, 360) / test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-11 success Logs for s390x-gcc / build / build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-12 success Logs for s390x-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-16 success Logs for s390x-gcc / test (test_verifier, false, 360) / test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for aarch64-gcc / test (test_verifier, false, 360) / test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for s390x-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-18 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-20 success Logs for x86_64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-21 success Logs for x86_64-gcc / test (test_maps, false, 360) / test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-4 success Logs for aarch64-gcc / build / build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-10 success Logs for aarch64-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-24 success Logs for x86_64-gcc / test (test_progs_no_alu32_parallel, true, 30) / test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-25 success Logs for x86_64-gcc / test (test_progs_parallel, true, 30) / test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-26 success Logs for x86_64-gcc / test (test_verifier, false, 360) / test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-27 success Logs for x86_64-gcc / veristat / veristat on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-28 success Logs for x86_64-llvm-17 / build / build for x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-29 success Logs for x86_64-llvm-17 / build-release / build for x86_64 with llvm-17-O2
bpf/vmtest-bpf-next-VM_Test-30 success Logs for x86_64-llvm-17 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-31 success Logs for x86_64-llvm-17 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-32 success Logs for x86_64-llvm-17 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-33 success Logs for x86_64-llvm-17 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-34 success Logs for x86_64-llvm-17 / veristat
bpf/vmtest-bpf-next-VM_Test-35 success Logs for x86_64-llvm-18 / build / build for x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-36 success Logs for x86_64-llvm-18 / build-release / build for x86_64 with llvm-18-O2
bpf/vmtest-bpf-next-VM_Test-37 success Logs for x86_64-llvm-18 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-41 success Logs for x86_64-llvm-18 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-42 success Logs for x86_64-llvm-18 / veristat
bpf/vmtest-bpf-next-VM_Test-15 success Logs for s390x-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for s390x-gcc / test (test_progs, false, 360) / test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-8 success Logs for aarch64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for aarch64-gcc / test (test_progs, false, 360) / test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for x86_64-gcc / test (test_progs, false, 360) / test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for x86_64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-38 success Logs for x86_64-llvm-18 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-39 success Logs for x86_64-llvm-18 / test (test_progs_cpuv4, false, 360) / test_progs_cpuv4 on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-40 success Logs for x86_64-llvm-18 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-18
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 272 this patch: 272
netdev/build_tools success Errors and warnings before: 10 this patch: 10
netdev/cc_maintainers warning 13 maintainers not CCed: kpsingh@kernel.org haoluo@google.com edumazet@google.com kuba@kernel.org daniel@iogearbox.net john.fastabend@gmail.com jolsa@kernel.org yonghong.song@linux.dev eddyz87@gmail.com pabeni@redhat.com sdf@fomichev.me netdev@vger.kernel.org johannes.berg@intel.com
netdev/build_clang success Errors and warnings before: 340 this patch: 340
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 7032 this patch: 7032
netdev/checkpatch warning CHECK: multiple assignments should be avoided WARNING: line length of 81 exceeds 80 columns WARNING: line length of 86 exceeds 80 columns WARNING: line length of 87 exceeds 80 columns WARNING: line length of 88 exceeds 80 columns WARNING: line length of 89 exceeds 80 columns WARNING: line length of 94 exceeds 80 columns WARNING: line length of 97 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 15 this patch: 15
netdev/source_inline success Was 0 now: 0

Commit Message

Kui-Feng Lee Aug. 7, 2024, 11:57 p.m. UTC
User kptrs are pinned with pin_user_pages_fast() and translated to kernel
addresses when a map value is updated by a user program (i.e., when
bpf_map_update_elem() is called from user space). The pinned pages are
unpinned when the values of user kptrs are overwritten or when the map
values are deleted/destroyed.

The pages are mapped through vmap() in order to get a contiguous address
range in the kernel if the memory pointed to by a user kptr spans two or
more pages. For the single-page case, page_address() is used to get the
kernel address of the page.
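
A condensed sketch of this pin-and-translate step, mirroring
trans_addr_pages() and bpf_obj_trans_pin_uaddr() in the diff below
(kernel context, <linux/mm.h>/<linux/vmalloc.h>; error handling and
rollback elided):

	/* Simplified sketch; see the real functions in kernel/bpf/syscall.c
	 * in this patch for the complete version.
	 */
	static void *pin_and_translate(void __user *uptr, u32 size)
	{
		struct page *pages[KPTR_USER_MAX_PAGES];
		unsigned long start = (unsigned long)uptr & PAGE_MASK;
		unsigned long off = (unsigned long)uptr & ~PAGE_MASK;
		/* e.g. a 32-byte object starting 8 bytes before a page
		 * boundary spans two pages, so npages == 2 here.
		 */
		int npages = PAGE_ALIGN(off + size) >> PAGE_SHIFT;
		void *kaddr;

		if (npages > KPTR_USER_MAX_PAGES)
			return NULL;
		if (pin_user_pages_fast(start, npages, 0, pages) != npages)
			return NULL;

		kaddr = npages == 1 ? page_address(pages[0]) :
				      vmap(pages, npages, VM_MAP, PAGE_KERNEL);
		return kaddr ? kaddr + off : NULL;
	}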

User kptrs are only supported by task storage maps.

One user kptr can pin at most KPTR_USER_MAX_PAGES (16) physical pages.
This number was picked arbitrarily as a safety limit; the restriction
could actually be removed entirely.

User kptrs can only be set by user programs through syscalls. Any update
of a map value that contains a __kptr_user field ignores the user kptr
values supplied by BPF programs: the existing user kptr values are kept
as they were when the new values come from BPF programs rather than from
user programs.

Cc: linux-mm@kvack.org
Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
---
 include/linux/bpf.h               |  35 +++++-
 include/linux/bpf_local_storage.h |   2 +-
 kernel/bpf/bpf_local_storage.c    |  18 +--
 kernel/bpf/helpers.c              |  12 +-
 kernel/bpf/local_storage.c        |   2 +-
 kernel/bpf/syscall.c              | 177 +++++++++++++++++++++++++++++-
 net/core/bpf_sk_storage.c         |   2 +-
 7 files changed, 227 insertions(+), 21 deletions(-)
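
For illustration, a hypothetical usage sketch of a __kptr_user field in a
task storage map. The struct, map, and variable names below are made up
for this example and are not part of the patch; only the __kptr_user tag
and the task-storage restriction come from this series.

	/* BPF side: a task storage map whose value embeds a user kptr. */
	struct user_data {
		long counter;
	};

	struct value {
		struct user_data __kptr_user *udata;
	};

	struct {
		__uint(type, BPF_MAP_TYPE_TASK_STORAGE);
		__uint(map_flags, BPF_F_NO_PREALLOC);
		__type(key, int);
		__type(value, struct value);
	} task_store SEC(".maps");

	/* User-space side: the stored pointer refers to user memory; the
	 * kernel pins the pages and rewrites the field to a kernel address
	 * before any BPF program reads it.  map_fd is the fd of task_store.
	 */
	struct user_data data = { .counter = 0 };
	struct value v = { .udata = &data };
	int task_fd = syscall(SYS_pidfd_open, getpid(), 0);

	bpf_map_update_elem(map_fd, &task_fd, &v, BPF_ANY);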

Comments

Kui-Feng Lee Aug. 8, 2024, 12:05 a.m. UTC | #1
On 8/7/24 16:57, Kui-Feng Lee wrote:
> User kptrs are pinned, by pin_user_pages_fast(), and translated to an
> address in the kernel when the value is updated by user programs. (Call
> bpf_map_update_elem() from user programs.) And, the pinned pages are
> unpinned if the value of user kptrs are overritten or if the values of maps
> are deleted/destroyed.
> 
> The pages are mapped through vmap() in order to get a continuous space in
> the kernel if the memory pointed by a user kptr resides in two or more
> pages. For the case of single page, page_address() is called to get the
> address of a page in the kernel.
> 
> User kptr is only supported by task storage maps.
> 
> One user kptr can pin at most KPTR_USER_MAX_PAGES(16) physical pages. This
> is a random picked number for safety. We actually can remove this
> restriction totally.
> 
> User kptrs could only be set by user programs through syscalls.  Any
> attempts of updating the value of a map with __kptr_user in it should
> ignore the values of user kptrs from BPF programs. The values of user kptrs
> will keep as they were if the new values are from BPF programs, not from
> user programs.
> 
> Cc: linux-mm@kvack.org
> Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
> ---
>   include/linux/bpf.h               |  35 +++++-
>   include/linux/bpf_local_storage.h |   2 +-
>   kernel/bpf/bpf_local_storage.c    |  18 +--
>   kernel/bpf/helpers.c              |  12 +-
>   kernel/bpf/local_storage.c        |   2 +-
>   kernel/bpf/syscall.c              | 177 +++++++++++++++++++++++++++++-
>   net/core/bpf_sk_storage.c         |   2 +-
>   7 files changed, 227 insertions(+), 21 deletions(-)
> 
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 87d5f98249e2..f4ad0bc183cb 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -30,6 +30,7 @@
>   #include <linux/static_call.h>
>   #include <linux/memcontrol.h>
>   #include <linux/cfi.h>
> +#include <linux/mm.h>
>   
>   struct bpf_verifier_env;
>   struct bpf_verifier_log;
> @@ -477,10 +478,12 @@ static inline void bpf_long_memcpy(void *dst, const void *src, u32 size)
>   		data_race(*ldst++ = *lsrc++);
>   }
>   
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr);
> +
>   /* copy everything but bpf_spin_lock, bpf_timer, and kptrs. There could be one of each. */
>   static inline void bpf_obj_memcpy(struct btf_record *rec,
>   				  void *dst, void *src, u32 size,
> -				  bool long_memcpy)
> +				  bool long_memcpy, bool from_user)
>   {
>   	u32 curr_off = 0;
>   	int i;
> @@ -496,21 +499,40 @@ static inline void bpf_obj_memcpy(struct btf_record *rec,
>   	for (i = 0; i < rec->cnt; i++) {
>   		u32 next_off = rec->fields[i].offset;
>   		u32 sz = next_off - curr_off;
> +		void *addr;
>   
>   		memcpy(dst + curr_off, src + curr_off, sz);
> +		if (from_user && rec->fields[i].type == BPF_KPTR_USER) {
> +			/* Unpin old address.
> +			 *
> +			 * Alignments are guaranteed by btf_find_field_one().
> +			 */
> +			addr = *(void **)(dst + next_off);
> +			if (virt_addr_valid(addr))
> +				bpf_obj_unpin_uaddr(&rec->fields[i], addr);
> +			else if (addr)
> +				WARN_ON_ONCE(1);
> +
> +			*(void **)(dst + next_off) = *(void **)(src + next_off);
> +		}
>   		curr_off += rec->fields[i].size + sz;
>   	}
>   	memcpy(dst + curr_off, src + curr_off, size - curr_off);
>   }
>   
> +static inline void copy_map_value_user(struct bpf_map *map, void *dst, void *src, bool from_user)
> +{
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, from_user);
> +}
> +
>   static inline void copy_map_value(struct bpf_map *map, void *dst, void *src)
>   {
> -	bpf_obj_memcpy(map->record, dst, src, map->value_size, false);
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, false);
>   }
>   
>   static inline void copy_map_value_long(struct bpf_map *map, void *dst, void *src)
>   {
> -	bpf_obj_memcpy(map->record, dst, src, map->value_size, true);
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, true, false);
>   }
>   
>   static inline void bpf_obj_memzero(struct btf_record *rec, void *dst, u32 size)
> @@ -538,6 +560,8 @@ static inline void zero_map_value(struct bpf_map *map, void *dst)
>   	bpf_obj_memzero(map->record, dst, map->value_size);
>   }
>   
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +				bool lock_src, bool from_user);
>   void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>   			   bool lock_src);
>   void bpf_timer_cancel_and_free(void *timer);
> @@ -775,6 +799,11 @@ enum bpf_arg_type {
>   };
>   static_assert(__BPF_ARG_TYPE_MAX <= BPF_BASE_TYPE_LIMIT);
>   
> +#define BPF_MAP_UPDATE_FLAG_BITS 3
> +enum bpf_map_update_flag {
> +	BPF_FROM_USER = BIT(0 + BPF_MAP_UPDATE_FLAG_BITS)
> +};
> +
>   /* type of values returned from helper functions */
>   enum bpf_return_type {
>   	RET_INTEGER,			/* function returns integer */
> diff --git a/include/linux/bpf_local_storage.h b/include/linux/bpf_local_storage.h
> index dcddb0aef7d8..d337df68fa23 100644
> --- a/include/linux/bpf_local_storage.h
> +++ b/include/linux/bpf_local_storage.h
> @@ -181,7 +181,7 @@ void bpf_selem_link_map(struct bpf_local_storage_map *smap,
>   
>   struct bpf_local_storage_elem *
>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner, void *value,
> -		bool charge_mem, gfp_t gfp_flags);
> +		bool charge_mem, gfp_t gfp_flags, bool from_user);
>   
>   void bpf_selem_free(struct bpf_local_storage_elem *selem,
>   		    struct bpf_local_storage_map *smap,
> diff --git a/kernel/bpf/bpf_local_storage.c b/kernel/bpf/bpf_local_storage.c
> index c938dea5ddbf..c4cf09e27a19 100644
> --- a/kernel/bpf/bpf_local_storage.c
> +++ b/kernel/bpf/bpf_local_storage.c
> @@ -73,7 +73,7 @@ static bool selem_linked_to_map(const struct bpf_local_storage_elem *selem)
>   
>   struct bpf_local_storage_elem *
>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
> -		void *value, bool charge_mem, gfp_t gfp_flags)
> +		void *value, bool charge_mem, gfp_t gfp_flags, bool from_user)
>   {
>   	struct bpf_local_storage_elem *selem;
>   
> @@ -100,7 +100,7 @@ bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
>   
>   	if (selem) {
>   		if (value)
> -			copy_map_value(&smap->map, SDATA(selem)->data, value);
> +			copy_map_value_user(&smap->map, SDATA(selem)->data, value, from_user);
>   		/* No need to call check_and_init_map_value as memory is zero init */
>   		return selem;
>   	}
> @@ -530,9 +530,11 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   	struct bpf_local_storage_elem *alloc_selem, *selem = NULL;
>   	struct bpf_local_storage *local_storage;
>   	unsigned long flags;
> +	bool from_user = map_flags & BPF_FROM_USER;
>   	int err;
>   
>   	/* BPF_EXIST and BPF_NOEXIST cannot be both set */
> +	map_flags &= ~BPF_FROM_USER;
>   	if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
>   	    /* BPF_F_LOCK can only be used in a value with spin_lock */
>   	    unlikely((map_flags & BPF_F_LOCK) &&
> @@ -550,7 +552,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		if (err)
>   			return ERR_PTR(err);
>   
> -		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>   		if (!selem)
>   			return ERR_PTR(-ENOMEM);
>   
> @@ -575,8 +577,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		if (err)
>   			return ERR_PTR(err);
>   		if (old_sdata && selem_linked_to_storage_lockless(SELEM(old_sdata))) {
> -			copy_map_value_locked(&smap->map, old_sdata->data,
> -					      value, false);
> +			copy_map_value_locked_user(&smap->map, old_sdata->data,
> +						   value, false, from_user);
>   			return old_sdata;
>   		}
>   	}
> @@ -584,7 +586,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   	/* A lookup has just been done before and concluded a new selem is
>   	 * needed. The chance of an unnecessary alloc is unlikely.
>   	 */
> -	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>   	if (!alloc_selem)
>   		return ERR_PTR(-ENOMEM);
>   
> @@ -607,8 +609,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		goto unlock;
>   
>   	if (old_sdata && (map_flags & BPF_F_LOCK)) {
> -		copy_map_value_locked(&smap->map, old_sdata->data, value,
> -				      false);
> +		copy_map_value_locked_user(&smap->map, old_sdata->data, value,
> +					   false, from_user);
>   		selem = SELEM(old_sdata);
>   		goto unlock;
>   	}
> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
> index d02ae323996b..4aef86209fdd 100644
> --- a/kernel/bpf/helpers.c
> +++ b/kernel/bpf/helpers.c
> @@ -372,8 +372,8 @@ const struct bpf_func_proto bpf_spin_unlock_proto = {
>   	.arg1_btf_id    = BPF_PTR_POISON,
>   };
>   
> -void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> -			   bool lock_src)
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +				bool lock_src, bool from_user)
>   {
>   	struct bpf_spin_lock *lock;
>   
> @@ -383,11 +383,17 @@ void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>   		lock = dst + map->record->spin_lock_off;
>   	preempt_disable();
>   	__bpf_spin_lock_irqsave(lock);
> -	copy_map_value(map, dst, src);
> +	copy_map_value_user(map, dst, src, from_user);
>   	__bpf_spin_unlock_irqrestore(lock);
>   	preempt_enable();
>   }
>   
> +void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> +			   bool lock_src)
> +{
> +	copy_map_value_locked_user(map, dst, src, lock_src, false);
> +}
> +
>   BPF_CALL_0(bpf_jiffies64)
>   {
>   	return get_jiffies_64();
> diff --git a/kernel/bpf/local_storage.c b/kernel/bpf/local_storage.c
> index 3969eb0382af..62a12fa8ce9e 100644
> --- a/kernel/bpf/local_storage.c
> +++ b/kernel/bpf/local_storage.c
> @@ -147,7 +147,7 @@ static long cgroup_storage_update_elem(struct bpf_map *map, void *key,
>   	struct bpf_cgroup_storage *storage;
>   	struct bpf_storage_buffer *new;
>   
> -	if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST)))
> +	if (unlikely(flags & ~BPF_F_LOCK))

This is an unnecessary change.
It will be removed.

>   		return -EINVAL;
>   
>   	if (unlikely((flags & BPF_F_LOCK) &&
> diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
> index 90a25307480e..eaa2a9d13265 100644
> --- a/kernel/bpf/syscall.c
> +++ b/kernel/bpf/syscall.c
> @@ -155,8 +155,134 @@ static void maybe_wait_bpf_programs(struct bpf_map *map)
>   		synchronize_rcu();
>   }
>   
> -static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> -				void *key, void *value, __u64 flags)
> +static void *trans_addr_pages(struct page **pages, int npages)
> +{
> +	if (npages == 1)
> +		return page_address(pages[0]);
> +	/* For multiple pages, we need to use vmap() to get a contiguous
> +	 * virtual address range.
> +	 */
> +	return vmap(pages, npages, VM_MAP, PAGE_KERNEL);
> +}
> +
> +#define KPTR_USER_MAX_PAGES 16
> +
> +static int bpf_obj_trans_pin_uaddr(struct btf_field *field, void **addr)
> +{
> +	const struct btf_type *t;
> +	struct page *pages[KPTR_USER_MAX_PAGES];
> +	void *ptr, *kern_addr;
> +	u32 type_id, tsz;
> +	int r, npages;
> +
> +	ptr = *addr;
> +	type_id = field->kptr.btf_id;
> +	t = btf_type_id_size(field->kptr.btf, &type_id, &tsz);
> +	if (!t)
> +		return -EINVAL;
> +	if (tsz == 0) {
> +		*addr = NULL;
> +		return 0;
> +	}
> +
> +	npages = (((intptr_t)ptr + tsz + ~PAGE_MASK) -
> +		  ((intptr_t)ptr & PAGE_MASK)) >> PAGE_SHIFT;
> +	if (npages > KPTR_USER_MAX_PAGES)
> +		return -E2BIG;
> +	r = pin_user_pages_fast((intptr_t)ptr & PAGE_MASK, npages, 0, pages);
> +	if (r != npages)
> +		return -EINVAL;
> +	kern_addr = trans_addr_pages(pages, npages);
> +	if (!kern_addr)
> +		return -ENOMEM;
> +	*addr = kern_addr + ((intptr_t)ptr & ~PAGE_MASK);
> +	return 0;
> +}
> +
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr)
> +{
> +	struct page *pages[KPTR_USER_MAX_PAGES];
> +	int npages, i;
> +	u32 size, type_id;
> +	void *ptr;
> +
> +	type_id = field->kptr.btf_id;
> +	btf_type_id_size(field->kptr.btf, &type_id, &size);
> +	if (size == 0)
> +		return;
> +
> +	ptr = (void *)((intptr_t)addr & PAGE_MASK);
> +	npages = (((intptr_t)addr + size + ~PAGE_MASK) - (intptr_t)ptr) >> PAGE_SHIFT;
> +	for (i = 0; i < npages; i++) {
> +		pages[i] = virt_to_page(ptr);
> +		ptr += PAGE_SIZE;
> +	}
> +	if (npages > 1)
> +		/* Paired with vmap() in trans_addr_pages() */
> +		vunmap((void *)((intptr_t)addr & PAGE_MASK));
> +	unpin_user_pages(pages, npages);
> +}
> +
> +static int bpf_obj_trans_pin_uaddrs(struct btf_record *rec, void *src, u32 size)
> +{
> +	u32 next_off;
> +	int i, err;
> +
> +	if (IS_ERR_OR_NULL(rec))
> +		return 0;
> +
> +	if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +		return 0;
> +
> +	for (i = 0; i < rec->cnt; i++) {
> +		if (rec->fields[i].type != BPF_KPTR_USER)
> +			continue;
> +
> +		next_off = rec->fields[i].offset;
> +		if (next_off + sizeof(void *) > size)
> +			return -EINVAL;
> +		err = bpf_obj_trans_pin_uaddr(&rec->fields[i], src + next_off);
> +		if (!err)
> +			continue;
> +
> +		/* Rollback */
> +		for (i--; i >= 0; i--) {
> +			if (rec->fields[i].type != BPF_KPTR_USER)
> +				continue;
> +			next_off = rec->fields[i].offset;
> +			bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +			*(void **)(src + next_off) = NULL;
> +		}
> +
> +		return err;
> +	}
> +
> +	return 0;
> +}
> +
> +static void bpf_obj_unpin_uaddrs(struct btf_record *rec, void *src)
> +{
> +	u32 next_off;
> +	int i;
> +
> +	if (IS_ERR_OR_NULL(rec))
> +		return;
> +
> +	if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +		return;
> +
> +	for (i = 0; i < rec->cnt; i++) {
> +		if (rec->fields[i].type != BPF_KPTR_USER)
> +			continue;
> +
> +		next_off = rec->fields[i].offset;
> +		bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +		*(void **)(src + next_off) = NULL;
> +	}
> +}
> +
> +static int bpf_map_update_value_inner(struct bpf_map *map, struct file *map_file,
> +				      void *key, void *value, __u64 flags)
>   {
>   	int err;
>   
> @@ -208,6 +334,29 @@ static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>   	return err;
>   }
>   
> +static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> +				void *key, void *value, __u64 flags)
> +{
> +	int err;
> +
> +	if (flags & BPF_FROM_USER) {
> +		/* Pin user memory can lead to context switch, so we need
> +		 * to do it before potential RCU lock.
> +		 */
> +		err = bpf_obj_trans_pin_uaddrs(map->record, value,
> +					       bpf_map_value_size(map));
> +		if (err)
> +			return err;
> +	}
> +
> +	err = bpf_map_update_value_inner(map, map_file, key, value, flags);
> +
> +	if (err && (flags & BPF_FROM_USER))
> +		bpf_obj_unpin_uaddrs(map->record, value);
> +
> +	return err;
> +}
> +
>   static int bpf_map_copy_value(struct bpf_map *map, void *key, void *value,
>   			      __u64 flags)
>   {
> @@ -714,6 +863,11 @@ void bpf_obj_free_fields(const struct btf_record *rec, void *obj)
>   				field->kptr.dtor(xchgd_field);
>   			}
>   			break;
> +		case BPF_KPTR_USER:
> +			if (virt_addr_valid(*(void **)field_ptr))
> +				bpf_obj_unpin_uaddr(field, *(void **)field_ptr);
> +			*(void **)field_ptr = NULL;
> +			break;
>   		case BPF_LIST_HEAD:
>   			if (WARN_ON_ONCE(rec->spin_lock_off < 0))
>   				continue;
> @@ -1155,6 +1309,12 @@ static int map_check_btf(struct bpf_map *map, struct bpf_token *token,
>   					goto free_map_tab;
>   				}
>   				break;
> +			case BPF_KPTR_USER:
> +				if (map->map_type != BPF_MAP_TYPE_TASK_STORAGE) {
> +					ret = -EOPNOTSUPP;
> +					goto free_map_tab;
> +				}
> +				break;
>   			case BPF_LIST_HEAD:
>   			case BPF_RB_ROOT:
>   				if (map->map_type != BPF_MAP_TYPE_HASH &&
> @@ -1618,11 +1778,15 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>   	struct bpf_map *map;
>   	void *key, *value;
>   	u32 value_size;
> +	u64 extra_flags = 0;
>   	struct fd f;
>   	int err;
>   
>   	if (CHECK_ATTR(BPF_MAP_UPDATE_ELEM))
>   		return -EINVAL;
> +	/* Prevent userspace from setting any internal flags */
> +	if (attr->flags & ~(BIT(BPF_MAP_UPDATE_FLAG_BITS) - 1))
> +		return -EINVAL;
>   
>   	f = fdget(ufd);
>   	map = __bpf_map_get(f);
> @@ -1653,7 +1817,9 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>   		goto free_key;
>   	}
>   
> -	err = bpf_map_update_value(map, f.file, key, value, attr->flags);
> +	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +		extra_flags |= BPF_FROM_USER;
> +	err = bpf_map_update_value(map, f.file, key, value, attr->flags | extra_flags);
>   	if (!err)
>   		maybe_wait_bpf_programs(map);
>   
> @@ -1852,6 +2018,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   	void __user *keys = u64_to_user_ptr(attr->batch.keys);
>   	u32 value_size, cp, max_count;
>   	void *key, *value;
> +	u64 extra_flags = 0;
>   	int err = 0;
>   
>   	if (attr->batch.elem_flags & ~BPF_F_LOCK)
> @@ -1881,6 +2048,8 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   		return -ENOMEM;
>   	}
>   
> +	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +		extra_flags |= BPF_FROM_USER;
>   	for (cp = 0; cp < max_count; cp++) {
>   		err = -EFAULT;
>   		if (copy_from_user(key, keys + cp * map->key_size,
> @@ -1889,7 +2058,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   			break;
>   
>   		err = bpf_map_update_value(map, map_file, key, value,
> -					   attr->batch.elem_flags);
> +					   attr->batch.elem_flags | extra_flags);
>   
>   		if (err)
>   			break;
> diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c
> index bc01b3aa6b0f..db5281384e6a 100644
> --- a/net/core/bpf_sk_storage.c
> +++ b/net/core/bpf_sk_storage.c
> @@ -137,7 +137,7 @@ bpf_sk_storage_clone_elem(struct sock *newsk,
>   {
>   	struct bpf_local_storage_elem *copy_selem;
>   
> -	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
> +	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC, false);
>   	if (!copy_selem)
>   		return NULL;
>
Kui-Feng Lee Aug. 8, 2024, 12:39 a.m. UTC | #2
bpf_obj_trans_pin_uaddrs() is where pinning and mapping are performed. It
is called when a syscall updates the value of a map. This function
rewrites the user kptr values to addresses in the kernel.


On 8/7/24 16:57, Kui-Feng Lee wrote:
> User kptrs are pinned, by pin_user_pages_fast(), and translated to an
> address in the kernel when the value is updated by user programs. (Call
> bpf_map_update_elem() from user programs.) And, the pinned pages are
> unpinned if the value of user kptrs are overritten or if the values of maps
> are deleted/destroyed.
> 
> The pages are mapped through vmap() in order to get a continuous space in
> the kernel if the memory pointed by a user kptr resides in two or more
> pages. For the case of single page, page_address() is called to get the
> address of a page in the kernel.
> 
> User kptr is only supported by task storage maps.
> 
> One user kptr can pin at most KPTR_USER_MAX_PAGES(16) physical pages. This
> is a random picked number for safety. We actually can remove this
> restriction totally.
> 
> User kptrs could only be set by user programs through syscalls.  Any
> attempts of updating the value of a map with __kptr_user in it should
> ignore the values of user kptrs from BPF programs. The values of user kptrs
> will keep as they were if the new values are from BPF programs, not from
> user programs.
> 
> Cc: linux-mm@kvack.org
> Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
> ---
>   include/linux/bpf.h               |  35 +++++-
>   include/linux/bpf_local_storage.h |   2 +-
>   kernel/bpf/bpf_local_storage.c    |  18 +--
>   kernel/bpf/helpers.c              |  12 +-
>   kernel/bpf/local_storage.c        |   2 +-
>   kernel/bpf/syscall.c              | 177 +++++++++++++++++++++++++++++-
>   net/core/bpf_sk_storage.c         |   2 +-
>   7 files changed, 227 insertions(+), 21 deletions(-)
> 
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 87d5f98249e2..f4ad0bc183cb 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -30,6 +30,7 @@
>   #include <linux/static_call.h>
>   #include <linux/memcontrol.h>
>   #include <linux/cfi.h>
> +#include <linux/mm.h>
>   
>   struct bpf_verifier_env;
>   struct bpf_verifier_log;
> @@ -477,10 +478,12 @@ static inline void bpf_long_memcpy(void *dst, const void *src, u32 size)
>   		data_race(*ldst++ = *lsrc++);
>   }
>   
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr);
> +
>   /* copy everything but bpf_spin_lock, bpf_timer, and kptrs. There could be one of each. */
>   static inline void bpf_obj_memcpy(struct btf_record *rec,
>   				  void *dst, void *src, u32 size,
> -				  bool long_memcpy)
> +				  bool long_memcpy, bool from_user)
>   {
>   	u32 curr_off = 0;
>   	int i;
> @@ -496,21 +499,40 @@ static inline void bpf_obj_memcpy(struct btf_record *rec,
>   	for (i = 0; i < rec->cnt; i++) {
>   		u32 next_off = rec->fields[i].offset;
>   		u32 sz = next_off - curr_off;
> +		void *addr;
>   
>   		memcpy(dst + curr_off, src + curr_off, sz);
> +		if (from_user && rec->fields[i].type == BPF_KPTR_USER) {
> +			/* Unpin old address.
> +			 *
> +			 * Alignments are guaranteed by btf_find_field_one().
> +			 */
> +			addr = *(void **)(dst + next_off);
> +			if (virt_addr_valid(addr))
> +				bpf_obj_unpin_uaddr(&rec->fields[i], addr);
> +			else if (addr)
> +				WARN_ON_ONCE(1);
> +
> +			*(void **)(dst + next_off) = *(void **)(src + next_off);
> +		}
>   		curr_off += rec->fields[i].size + sz;
>   	}
>   	memcpy(dst + curr_off, src + curr_off, size - curr_off);
>   }
>   
> +static inline void copy_map_value_user(struct bpf_map *map, void *dst, void *src, bool from_user)
> +{
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, from_user);
> +}
> +
>   static inline void copy_map_value(struct bpf_map *map, void *dst, void *src)
>   {
> -	bpf_obj_memcpy(map->record, dst, src, map->value_size, false);
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, false);
>   }
>   
>   static inline void copy_map_value_long(struct bpf_map *map, void *dst, void *src)
>   {
> -	bpf_obj_memcpy(map->record, dst, src, map->value_size, true);
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, true, false);
>   }
>   
>   static inline void bpf_obj_memzero(struct btf_record *rec, void *dst, u32 size)
> @@ -538,6 +560,8 @@ static inline void zero_map_value(struct bpf_map *map, void *dst)
>   	bpf_obj_memzero(map->record, dst, map->value_size);
>   }
>   
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +				bool lock_src, bool from_user);
>   void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>   			   bool lock_src);
>   void bpf_timer_cancel_and_free(void *timer);
> @@ -775,6 +799,11 @@ enum bpf_arg_type {
>   };
>   static_assert(__BPF_ARG_TYPE_MAX <= BPF_BASE_TYPE_LIMIT);
>   
> +#define BPF_MAP_UPDATE_FLAG_BITS 3
> +enum bpf_map_update_flag {
> +	BPF_FROM_USER = BIT(0 + BPF_MAP_UPDATE_FLAG_BITS)
> +};
> +
>   /* type of values returned from helper functions */
>   enum bpf_return_type {
>   	RET_INTEGER,			/* function returns integer */
> diff --git a/include/linux/bpf_local_storage.h b/include/linux/bpf_local_storage.h
> index dcddb0aef7d8..d337df68fa23 100644
> --- a/include/linux/bpf_local_storage.h
> +++ b/include/linux/bpf_local_storage.h
> @@ -181,7 +181,7 @@ void bpf_selem_link_map(struct bpf_local_storage_map *smap,
>   
>   struct bpf_local_storage_elem *
>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner, void *value,
> -		bool charge_mem, gfp_t gfp_flags);
> +		bool charge_mem, gfp_t gfp_flags, bool from_user);
>   
>   void bpf_selem_free(struct bpf_local_storage_elem *selem,
>   		    struct bpf_local_storage_map *smap,
> diff --git a/kernel/bpf/bpf_local_storage.c b/kernel/bpf/bpf_local_storage.c
> index c938dea5ddbf..c4cf09e27a19 100644
> --- a/kernel/bpf/bpf_local_storage.c
> +++ b/kernel/bpf/bpf_local_storage.c
> @@ -73,7 +73,7 @@ static bool selem_linked_to_map(const struct bpf_local_storage_elem *selem)
>   
>   struct bpf_local_storage_elem *
>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
> -		void *value, bool charge_mem, gfp_t gfp_flags)
> +		void *value, bool charge_mem, gfp_t gfp_flags, bool from_user)
>   {
>   	struct bpf_local_storage_elem *selem;
>   
> @@ -100,7 +100,7 @@ bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
>   
>   	if (selem) {
>   		if (value)
> -			copy_map_value(&smap->map, SDATA(selem)->data, value);
> +			copy_map_value_user(&smap->map, SDATA(selem)->data, value, from_user);
>   		/* No need to call check_and_init_map_value as memory is zero init */
>   		return selem;
>   	}
> @@ -530,9 +530,11 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   	struct bpf_local_storage_elem *alloc_selem, *selem = NULL;
>   	struct bpf_local_storage *local_storage;
>   	unsigned long flags;
> +	bool from_user = map_flags & BPF_FROM_USER;
>   	int err;
>   
>   	/* BPF_EXIST and BPF_NOEXIST cannot be both set */
> +	map_flags &= ~BPF_FROM_USER;
>   	if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
>   	    /* BPF_F_LOCK can only be used in a value with spin_lock */
>   	    unlikely((map_flags & BPF_F_LOCK) &&
> @@ -550,7 +552,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		if (err)
>   			return ERR_PTR(err);
>   
> -		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>   		if (!selem)
>   			return ERR_PTR(-ENOMEM);
>   
> @@ -575,8 +577,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		if (err)
>   			return ERR_PTR(err);
>   		if (old_sdata && selem_linked_to_storage_lockless(SELEM(old_sdata))) {
> -			copy_map_value_locked(&smap->map, old_sdata->data,
> -					      value, false);
> +			copy_map_value_locked_user(&smap->map, old_sdata->data,
> +						   value, false, from_user);
>   			return old_sdata;
>   		}
>   	}
> @@ -584,7 +586,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   	/* A lookup has just been done before and concluded a new selem is
>   	 * needed. The chance of an unnecessary alloc is unlikely.
>   	 */
> -	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>   	if (!alloc_selem)
>   		return ERR_PTR(-ENOMEM);
>   
> @@ -607,8 +609,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		goto unlock;
>   
>   	if (old_sdata && (map_flags & BPF_F_LOCK)) {
> -		copy_map_value_locked(&smap->map, old_sdata->data, value,
> -				      false);
> +		copy_map_value_locked_user(&smap->map, old_sdata->data, value,
> +					   false, from_user);
>   		selem = SELEM(old_sdata);
>   		goto unlock;
>   	}
> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
> index d02ae323996b..4aef86209fdd 100644
> --- a/kernel/bpf/helpers.c
> +++ b/kernel/bpf/helpers.c
> @@ -372,8 +372,8 @@ const struct bpf_func_proto bpf_spin_unlock_proto = {
>   	.arg1_btf_id    = BPF_PTR_POISON,
>   };
>   
> -void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> -			   bool lock_src)
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +				bool lock_src, bool from_user)
>   {
>   	struct bpf_spin_lock *lock;
>   
> @@ -383,11 +383,17 @@ void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>   		lock = dst + map->record->spin_lock_off;
>   	preempt_disable();
>   	__bpf_spin_lock_irqsave(lock);
> -	copy_map_value(map, dst, src);
> +	copy_map_value_user(map, dst, src, from_user);
>   	__bpf_spin_unlock_irqrestore(lock);
>   	preempt_enable();
>   }
>   
> +void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> +			   bool lock_src)
> +{
> +	copy_map_value_locked_user(map, dst, src, lock_src, false);
> +}
> +
>   BPF_CALL_0(bpf_jiffies64)
>   {
>   	return get_jiffies_64();
> diff --git a/kernel/bpf/local_storage.c b/kernel/bpf/local_storage.c
> index 3969eb0382af..62a12fa8ce9e 100644
> --- a/kernel/bpf/local_storage.c
> +++ b/kernel/bpf/local_storage.c
> @@ -147,7 +147,7 @@ static long cgroup_storage_update_elem(struct bpf_map *map, void *key,
>   	struct bpf_cgroup_storage *storage;
>   	struct bpf_storage_buffer *new;
>   
> -	if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST)))
> +	if (unlikely(flags & ~BPF_F_LOCK))
>   		return -EINVAL;
>   
>   	if (unlikely((flags & BPF_F_LOCK) &&
> diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
> index 90a25307480e..eaa2a9d13265 100644
> --- a/kernel/bpf/syscall.c
> +++ b/kernel/bpf/syscall.c
> @@ -155,8 +155,134 @@ static void maybe_wait_bpf_programs(struct bpf_map *map)
>   		synchronize_rcu();
>   }
>   
> -static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> -				void *key, void *value, __u64 flags)
> +static void *trans_addr_pages(struct page **pages, int npages)
> +{
> +	if (npages == 1)
> +		return page_address(pages[0]);
> +	/* For multiple pages, we need to use vmap() to get a contiguous
> +	 * virtual address range.
> +	 */
> +	return vmap(pages, npages, VM_MAP, PAGE_KERNEL);
> +}
> +
> +#define KPTR_USER_MAX_PAGES 16
> +
> +static int bpf_obj_trans_pin_uaddr(struct btf_field *field, void **addr)
> +{
> +	const struct btf_type *t;
> +	struct page *pages[KPTR_USER_MAX_PAGES];
> +	void *ptr, *kern_addr;
> +	u32 type_id, tsz;
> +	int r, npages;
> +
> +	ptr = *addr;
> +	type_id = field->kptr.btf_id;
> +	t = btf_type_id_size(field->kptr.btf, &type_id, &tsz);
> +	if (!t)
> +		return -EINVAL;
> +	if (tsz == 0) {
> +		*addr = NULL;
> +		return 0;
> +	}
> +
> +	npages = (((intptr_t)ptr + tsz + ~PAGE_MASK) -
> +		  ((intptr_t)ptr & PAGE_MASK)) >> PAGE_SHIFT;
> +	if (npages > KPTR_USER_MAX_PAGES)
> +		return -E2BIG;
> +	r = pin_user_pages_fast((intptr_t)ptr & PAGE_MASK, npages, 0, pages);
> +	if (r != npages)
> +		return -EINVAL;
> +	kern_addr = trans_addr_pages(pages, npages);
> +	if (!kern_addr)
> +		return -ENOMEM;
> +	*addr = kern_addr + ((intptr_t)ptr & ~PAGE_MASK);
> +	return 0;
> +}
> +
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr)
> +{
> +	struct page *pages[KPTR_USER_MAX_PAGES];
> +	int npages, i;
> +	u32 size, type_id;
> +	void *ptr;
> +
> +	type_id = field->kptr.btf_id;
> +	btf_type_id_size(field->kptr.btf, &type_id, &size);
> +	if (size == 0)
> +		return;
> +
> +	ptr = (void *)((intptr_t)addr & PAGE_MASK);
> +	npages = (((intptr_t)addr + size + ~PAGE_MASK) - (intptr_t)ptr) >> PAGE_SHIFT;
> +	for (i = 0; i < npages; i++) {
> +		pages[i] = virt_to_page(ptr);
> +		ptr += PAGE_SIZE;
> +	}
> +	if (npages > 1)
> +		/* Paired with vmap() in trans_addr_pages() */
> +		vunmap((void *)((intptr_t)addr & PAGE_MASK));
> +	unpin_user_pages(pages, npages);
> +}
> +
> +static int bpf_obj_trans_pin_uaddrs(struct btf_record *rec, void *src, u32 size)
> +{
> +	u32 next_off;
> +	int i, err;
> +
> +	if (IS_ERR_OR_NULL(rec))
> +		return 0;
> +
> +	if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +		return 0;
> +
> +	for (i = 0; i < rec->cnt; i++) {
> +		if (rec->fields[i].type != BPF_KPTR_USER)
> +			continue;
> +
> +		next_off = rec->fields[i].offset;
> +		if (next_off + sizeof(void *) > size)
> +			return -EINVAL;
> +		err = bpf_obj_trans_pin_uaddr(&rec->fields[i], src + next_off);
> +		if (!err)
> +			continue;
> +
> +		/* Rollback */
> +		for (i--; i >= 0; i--) {
> +			if (rec->fields[i].type != BPF_KPTR_USER)
> +				continue;
> +			next_off = rec->fields[i].offset;
> +			bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +			*(void **)(src + next_off) = NULL;
> +		}
> +
> +		return err;
> +	}
> +
> +	return 0;
> +}
> +
> +static void bpf_obj_unpin_uaddrs(struct btf_record *rec, void *src)
> +{
> +	u32 next_off;
> +	int i;
> +
> +	if (IS_ERR_OR_NULL(rec))
> +		return;
> +
> +	if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +		return;
> +
> +	for (i = 0; i < rec->cnt; i++) {
> +		if (rec->fields[i].type != BPF_KPTR_USER)
> +			continue;
> +
> +		next_off = rec->fields[i].offset;
> +		bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +		*(void **)(src + next_off) = NULL;
> +	}
> +}
> +
> +static int bpf_map_update_value_inner(struct bpf_map *map, struct file *map_file,
> +				      void *key, void *value, __u64 flags)
>   {
>   	int err;
>   
> @@ -208,6 +334,29 @@ static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>   	return err;
>   }
>   
> +static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> +				void *key, void *value, __u64 flags)
> +{
> +	int err;
> +
> +	if (flags & BPF_FROM_USER) {
> +		/* Pin user memory can lead to context switch, so we need
> +		 * to do it before potential RCU lock.
> +		 */
> +		err = bpf_obj_trans_pin_uaddrs(map->record, value,
> +					       bpf_map_value_size(map));
> +		if (err)
> +			return err;
> +	}
> +
> +	err = bpf_map_update_value_inner(map, map_file, key, value, flags);
> +
> +	if (err && (flags & BPF_FROM_USER))
> +		bpf_obj_unpin_uaddrs(map->record, value);
> +
> +	return err;
> +}
> +
>   static int bpf_map_copy_value(struct bpf_map *map, void *key, void *value,
>   			      __u64 flags)
>   {
> @@ -714,6 +863,11 @@ void bpf_obj_free_fields(const struct btf_record *rec, void *obj)
>   				field->kptr.dtor(xchgd_field);
>   			}
>   			break;
> +		case BPF_KPTR_USER:
> +			if (virt_addr_valid(*(void **)field_ptr))
> +				bpf_obj_unpin_uaddr(field, *(void **)field_ptr);
> +			*(void **)field_ptr = NULL;
> +			break;
>   		case BPF_LIST_HEAD:
>   			if (WARN_ON_ONCE(rec->spin_lock_off < 0))
>   				continue;
> @@ -1155,6 +1309,12 @@ static int map_check_btf(struct bpf_map *map, struct bpf_token *token,
>   					goto free_map_tab;
>   				}
>   				break;
> +			case BPF_KPTR_USER:
> +				if (map->map_type != BPF_MAP_TYPE_TASK_STORAGE) {
> +					ret = -EOPNOTSUPP;
> +					goto free_map_tab;
> +				}
> +				break;
>   			case BPF_LIST_HEAD:
>   			case BPF_RB_ROOT:
>   				if (map->map_type != BPF_MAP_TYPE_HASH &&
> @@ -1618,11 +1778,15 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>   	struct bpf_map *map;
>   	void *key, *value;
>   	u32 value_size;
> +	u64 extra_flags = 0;
>   	struct fd f;
>   	int err;
>   
>   	if (CHECK_ATTR(BPF_MAP_UPDATE_ELEM))
>   		return -EINVAL;
> +	/* Prevent userspace from setting any internal flags */
> +	if (attr->flags & ~(BIT(BPF_MAP_UPDATE_FLAG_BITS) - 1))
> +		return -EINVAL;
>   
>   	f = fdget(ufd);
>   	map = __bpf_map_get(f);
> @@ -1653,7 +1817,9 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>   		goto free_key;
>   	}
>   
> -	err = bpf_map_update_value(map, f.file, key, value, attr->flags);
> +	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +		extra_flags |= BPF_FROM_USER;
> +	err = bpf_map_update_value(map, f.file, key, value, attr->flags | extra_flags);
>   	if (!err)
>   		maybe_wait_bpf_programs(map);
>   
> @@ -1852,6 +2018,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   	void __user *keys = u64_to_user_ptr(attr->batch.keys);
>   	u32 value_size, cp, max_count;
>   	void *key, *value;
> +	u64 extra_flags = 0;
>   	int err = 0;
>   
>   	if (attr->batch.elem_flags & ~BPF_F_LOCK)
> @@ -1881,6 +2048,8 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   		return -ENOMEM;
>   	}
>   
> +	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +		extra_flags |= BPF_FROM_USER;
>   	for (cp = 0; cp < max_count; cp++) {
>   		err = -EFAULT;
>   		if (copy_from_user(key, keys + cp * map->key_size,
> @@ -1889,7 +2058,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   			break;
>   
>   		err = bpf_map_update_value(map, map_file, key, value,
> -					   attr->batch.elem_flags);
> +					   attr->batch.elem_flags | extra_flags);
>   
>   		if (err)
>   			break;
> diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c
> index bc01b3aa6b0f..db5281384e6a 100644
> --- a/net/core/bpf_sk_storage.c
> +++ b/net/core/bpf_sk_storage.c
> @@ -137,7 +137,7 @@ bpf_sk_storage_clone_elem(struct sock *newsk,
>   {
>   	struct bpf_local_storage_elem *copy_selem;
>   
> -	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
> +	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC, false);
>   	if (!copy_selem)
>   		return NULL;
>
Kui-Feng Lee Aug. 12, 2024, 4 p.m. UTC | #3
On 8/7/24 16:57, Kui-Feng Lee wrote:
> User kptrs are pinned, by pin_user_pages_fast(), and translated to an
> address in the kernel when the value is updated by user programs. (Call
> bpf_map_update_elem() from user programs.) And, the pinned pages are
> unpinned if the value of user kptrs are overritten or if the values of maps
> are deleted/destroyed.
> 
> The pages are mapped through vmap() in order to get a continuous space in
> the kernel if the memory pointed by a user kptr resides in two or more
> pages. For the case of single page, page_address() is called to get the
> address of a page in the kernel.
> 
> User kptr is only supported by task storage maps.
> 
> One user kptr can pin at most KPTR_USER_MAX_PAGES(16) physical pages. This
> is a random picked number for safety. We actually can remove this
> restriction totally.
> 
> User kptrs could only be set by user programs through syscalls.  Any
> attempts of updating the value of a map with __kptr_user in it should
> ignore the values of user kptrs from BPF programs. The values of user kptrs
> will keep as they were if the new values are from BPF programs, not from
> user programs.
> 
> Cc: linux-mm@kvack.org
> Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
> ---
>   include/linux/bpf.h               |  35 +++++-
>   include/linux/bpf_local_storage.h |   2 +-
>   kernel/bpf/bpf_local_storage.c    |  18 +--
>   kernel/bpf/helpers.c              |  12 +-
>   kernel/bpf/local_storage.c        |   2 +-
>   kernel/bpf/syscall.c              | 177 +++++++++++++++++++++++++++++-
>   net/core/bpf_sk_storage.c         |   2 +-
>   7 files changed, 227 insertions(+), 21 deletions(-)
> 
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 87d5f98249e2..f4ad0bc183cb 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -30,6 +30,7 @@
>   #include <linux/static_call.h>
>   #include <linux/memcontrol.h>
>   #include <linux/cfi.h>
> +#include <linux/mm.h>
>   
>   struct bpf_verifier_env;
>   struct bpf_verifier_log;
> @@ -477,10 +478,12 @@ static inline void bpf_long_memcpy(void *dst, const void *src, u32 size)
>   		data_race(*ldst++ = *lsrc++);
>   }
>   
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr);
> +
>   /* copy everything but bpf_spin_lock, bpf_timer, and kptrs. There could be one of each. */
>   static inline void bpf_obj_memcpy(struct btf_record *rec,
>   				  void *dst, void *src, u32 size,
> -				  bool long_memcpy)
> +				  bool long_memcpy, bool from_user)
>   {
>   	u32 curr_off = 0;
>   	int i;
> @@ -496,21 +499,40 @@ static inline void bpf_obj_memcpy(struct btf_record *rec,
>   	for (i = 0; i < rec->cnt; i++) {
>   		u32 next_off = rec->fields[i].offset;
>   		u32 sz = next_off - curr_off;
> +		void *addr;
>   
>   		memcpy(dst + curr_off, src + curr_off, sz);
> +		if (from_user && rec->fields[i].type == BPF_KPTR_USER) {
> +			/* Unpin old address.
> +			 *
> +			 * Alignments are guaranteed by btf_find_field_one().
> +			 */
> +			addr = *(void **)(dst + next_off);
> +			if (virt_addr_valid(addr))
> +				bpf_obj_unpin_uaddr(&rec->fields[i], addr);
> +			else if (addr)
> +				WARN_ON_ONCE(1);
> +
> +			*(void **)(dst + next_off) = *(void **)(src + next_off);
> +		}
>   		curr_off += rec->fields[i].size + sz;
>   	}
>   	memcpy(dst + curr_off, src + curr_off, size - curr_off);
>   }
>   
> +static inline void copy_map_value_user(struct bpf_map *map, void *dst, void *src, bool from_user)
> +{
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, from_user);
> +}
> +
>   static inline void copy_map_value(struct bpf_map *map, void *dst, void *src)
>   {
> -	bpf_obj_memcpy(map->record, dst, src, map->value_size, false);
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, false);
>   }
>   
>   static inline void copy_map_value_long(struct bpf_map *map, void *dst, void *src)
>   {
> -	bpf_obj_memcpy(map->record, dst, src, map->value_size, true);
> +	bpf_obj_memcpy(map->record, dst, src, map->value_size, true, false);
>   }
>   
>   static inline void bpf_obj_memzero(struct btf_record *rec, void *dst, u32 size)
> @@ -538,6 +560,8 @@ static inline void zero_map_value(struct bpf_map *map, void *dst)
>   	bpf_obj_memzero(map->record, dst, map->value_size);
>   }
>   
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +				bool lock_src, bool from_user);
>   void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>   			   bool lock_src);
>   void bpf_timer_cancel_and_free(void *timer);
> @@ -775,6 +799,11 @@ enum bpf_arg_type {
>   };
>   static_assert(__BPF_ARG_TYPE_MAX <= BPF_BASE_TYPE_LIMIT);
>   
> +#define BPF_MAP_UPDATE_FLAG_BITS 3
> +enum bpf_map_update_flag {
> +	BPF_FROM_USER = BIT(0 + BPF_MAP_UPDATE_FLAG_BITS)
> +};
> +
>   /* type of values returned from helper functions */
>   enum bpf_return_type {
>   	RET_INTEGER,			/* function returns integer */
> diff --git a/include/linux/bpf_local_storage.h b/include/linux/bpf_local_storage.h
> index dcddb0aef7d8..d337df68fa23 100644
> --- a/include/linux/bpf_local_storage.h
> +++ b/include/linux/bpf_local_storage.h
> @@ -181,7 +181,7 @@ void bpf_selem_link_map(struct bpf_local_storage_map *smap,
>   
>   struct bpf_local_storage_elem *
>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner, void *value,
> -		bool charge_mem, gfp_t gfp_flags);
> +		bool charge_mem, gfp_t gfp_flags, bool from_user);
>   
>   void bpf_selem_free(struct bpf_local_storage_elem *selem,
>   		    struct bpf_local_storage_map *smap,
> diff --git a/kernel/bpf/bpf_local_storage.c b/kernel/bpf/bpf_local_storage.c
> index c938dea5ddbf..c4cf09e27a19 100644
> --- a/kernel/bpf/bpf_local_storage.c
> +++ b/kernel/bpf/bpf_local_storage.c
> @@ -73,7 +73,7 @@ static bool selem_linked_to_map(const struct bpf_local_storage_elem *selem)
>   
>   struct bpf_local_storage_elem *
>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
> -		void *value, bool charge_mem, gfp_t gfp_flags)
> +		void *value, bool charge_mem, gfp_t gfp_flags, bool from_user)
>   {
>   	struct bpf_local_storage_elem *selem;
>   
> @@ -100,7 +100,7 @@ bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
>   
>   	if (selem) {
>   		if (value)
> -			copy_map_value(&smap->map, SDATA(selem)->data, value);
> +			copy_map_value_user(&smap->map, SDATA(selem)->data, value, from_user);
>   		/* No need to call check_and_init_map_value as memory is zero init */
>   		return selem;
>   	}
> @@ -530,9 +530,11 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   	struct bpf_local_storage_elem *alloc_selem, *selem = NULL;
>   	struct bpf_local_storage *local_storage;
>   	unsigned long flags;
> +	bool from_user = map_flags & BPF_FROM_USER;
>   	int err;
>   
>   	/* BPF_EXIST and BPF_NOEXIST cannot be both set */
> +	map_flags &= ~BPF_FROM_USER;
>   	if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
>   	    /* BPF_F_LOCK can only be used in a value with spin_lock */
>   	    unlikely((map_flags & BPF_F_LOCK) &&
> @@ -550,7 +552,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		if (err)
>   			return ERR_PTR(err);
>   
> -		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>   		if (!selem)
>   			return ERR_PTR(-ENOMEM);
>   
> @@ -575,8 +577,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		if (err)
>   			return ERR_PTR(err);
>   		if (old_sdata && selem_linked_to_storage_lockless(SELEM(old_sdata))) {
> -			copy_map_value_locked(&smap->map, old_sdata->data,
> -					      value, false);
> +			copy_map_value_locked_user(&smap->map, old_sdata->data,
> +						   value, false, from_user);
>   			return old_sdata;
>   		}
>   	}
> @@ -584,7 +586,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   	/* A lookup has just been done before and concluded a new selem is
>   	 * needed. The chance of an unnecessary alloc is unlikely.
>   	 */
> -	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>   	if (!alloc_selem)
>   		return ERR_PTR(-ENOMEM);
>   
> @@ -607,8 +609,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>   		goto unlock;
>   
>   	if (old_sdata && (map_flags & BPF_F_LOCK)) {
> -		copy_map_value_locked(&smap->map, old_sdata->data, value,
> -				      false);
> +		copy_map_value_locked_user(&smap->map, old_sdata->data, value,
> +					   false, from_user);
>   		selem = SELEM(old_sdata);
>   		goto unlock;
>   	}
> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
> index d02ae323996b..4aef86209fdd 100644
> --- a/kernel/bpf/helpers.c
> +++ b/kernel/bpf/helpers.c
> @@ -372,8 +372,8 @@ const struct bpf_func_proto bpf_spin_unlock_proto = {
>   	.arg1_btf_id    = BPF_PTR_POISON,
>   };
>   
> -void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> -			   bool lock_src)
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +				bool lock_src, bool from_user)
>   {
>   	struct bpf_spin_lock *lock;
>   
> @@ -383,11 +383,17 @@ void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>   		lock = dst + map->record->spin_lock_off;
>   	preempt_disable();
>   	__bpf_spin_lock_irqsave(lock);
> -	copy_map_value(map, dst, src);
> +	copy_map_value_user(map, dst, src, from_user);
>   	__bpf_spin_unlock_irqrestore(lock);
>   	preempt_enable();
>   }
>   
> +void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> +			   bool lock_src)
> +{
> +	copy_map_value_locked_user(map, dst, src, lock_src, false);
> +}
> +
>   BPF_CALL_0(bpf_jiffies64)
>   {
>   	return get_jiffies_64();
> diff --git a/kernel/bpf/local_storage.c b/kernel/bpf/local_storage.c
> index 3969eb0382af..62a12fa8ce9e 100644
> --- a/kernel/bpf/local_storage.c
> +++ b/kernel/bpf/local_storage.c
> @@ -147,7 +147,7 @@ static long cgroup_storage_update_elem(struct bpf_map *map, void *key,
>   	struct bpf_cgroup_storage *storage;
>   	struct bpf_storage_buffer *new;
>   
> -	if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST)))
> +	if (unlikely(flags & ~BPF_F_LOCK))
>   		return -EINVAL;
>   
>   	if (unlikely((flags & BPF_F_LOCK) &&
> diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
> index 90a25307480e..eaa2a9d13265 100644
> --- a/kernel/bpf/syscall.c
> +++ b/kernel/bpf/syscall.c
> @@ -155,8 +155,134 @@ static void maybe_wait_bpf_programs(struct bpf_map *map)
>   		synchronize_rcu();
>   }
>   
> -static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> -				void *key, void *value, __u64 flags)
> +static void *trans_addr_pages(struct page **pages, int npages)
> +{
> +	if (npages == 1)
> +		return page_address(pages[0]);
> +	/* For multiple pages, we need to use vmap() to get a contiguous
> +	 * virtual address range.
> +	 */
> +	return vmap(pages, npages, VM_MAP, PAGE_KERNEL);
> +}
> +
> +#define KPTR_USER_MAX_PAGES 16
> +
> +static int bpf_obj_trans_pin_uaddr(struct btf_field *field, void **addr)
> +{
> +	const struct btf_type *t;
> +	struct page *pages[KPTR_USER_MAX_PAGES];
> +	void *ptr, *kern_addr;
> +	u32 type_id, tsz;
> +	int r, npages;
> +
> +	ptr = *addr;
> +	type_id = field->kptr.btf_id;
> +	t = btf_type_id_size(field->kptr.btf, &type_id, &tsz);
> +	if (!t)
> +		return -EINVAL;
> +	if (tsz == 0) {
> +		*addr = NULL;
> +		return 0;
> +	}
> +
> +	npages = (((intptr_t)ptr + tsz + ~PAGE_MASK) -
> +		  ((intptr_t)ptr & PAGE_MASK)) >> PAGE_SHIFT;
> +	if (npages > KPTR_USER_MAX_PAGES)
> +		return -E2BIG;
> +	r = pin_user_pages_fast((intptr_t)ptr & PAGE_MASK, npages, 0, pages);
> +	if (r != npages)
> +		return -EINVAL;
> +	kern_addr = trans_addr_pages(pages, npages);
> +	if (!kern_addr)
> +		return -ENOMEM;
> +	*addr = kern_addr + ((intptr_t)ptr & ~PAGE_MASK);
> +	return 0;
> +}
> +
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr)
> +{
> +	struct page *pages[KPTR_USER_MAX_PAGES];
> +	int npages, i;
> +	u32 size, type_id;
> +	void *ptr;
> +
> +	type_id = field->kptr.btf_id;
> +	btf_type_id_size(field->kptr.btf, &type_id, &size);
> +	if (size == 0)
> +		return;
> +
> +	ptr = (void *)((intptr_t)addr & PAGE_MASK);
> +	npages = (((intptr_t)addr + size + ~PAGE_MASK) - (intptr_t)ptr) >> PAGE_SHIFT;
> +	for (i = 0; i < npages; i++) {
> +		pages[i] = virt_to_page(ptr);
> +		ptr += PAGE_SIZE;
> +	}
> +	if (npages > 1)
> +		/* Paired with vmap() in trans_addr_pages() */
> +		vunmap((void *)((intptr_t)addr & PAGE_MASK));

Just realized that vunmap() should not be called in a non-sleepable
context. I would add an async variant of vunmap() to defer the unmapping
to a workqueue.
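
Roughly something like this (untested sketch; the helper names are made
up for illustration only):

	struct bpf_vunmap_work {
		struct work_struct work;
		const void *addr;
	};

	static void bpf_vunmap_workfn(struct work_struct *work)
	{
		struct bpf_vunmap_work *vw =
			container_of(work, struct bpf_vunmap_work, work);

		/* workqueue context may sleep, so vunmap() is safe here */
		vunmap(vw->addr);
		kfree(vw);
	}

	static void bpf_vunmap_async(const void *addr)
	{
		struct bpf_vunmap_work *vw = kzalloc(sizeof(*vw), GFP_ATOMIC);

		if (!vw)
			return;	/* a real version needs a fallback here */
		vw->addr = addr;
		INIT_WORK(&vw->work, bpf_vunmap_workfn);
		schedule_work(&vw->work);
	}

Then bpf_obj_unpin_uaddr() could call bpf_vunmap_async() instead of
vunmap(); the unpin_user_pages() call would likely have to move into the
work item as well, so the pages stay pinned until the mapping is gone.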

> +	unpin_user_pages(pages, npages);
> +}
> +
> +static int bpf_obj_trans_pin_uaddrs(struct btf_record *rec, void *src, u32 size)
> +{
> +	u32 next_off;
> +	int i, err;
> +
> +	if (IS_ERR_OR_NULL(rec))
> +		return 0;
> +
> +	if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +		return 0;
> +
> +	for (i = 0; i < rec->cnt; i++) {
> +		if (rec->fields[i].type != BPF_KPTR_USER)
> +			continue;
> +
> +		next_off = rec->fields[i].offset;
> +		if (next_off + sizeof(void *) > size)
> +			return -EINVAL;
> +		err = bpf_obj_trans_pin_uaddr(&rec->fields[i], src + next_off);
> +		if (!err)
> +			continue;
> +
> +		/* Rollback */
> +		for (i--; i >= 0; i--) {
> +			if (rec->fields[i].type != BPF_KPTR_USER)
> +				continue;
> +			next_off = rec->fields[i].offset;
> +			bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +			*(void **)(src + next_off) = NULL;
> +		}
> +
> +		return err;
> +	}
> +
> +	return 0;
> +}
> +
> +static void bpf_obj_unpin_uaddrs(struct btf_record *rec, void *src)
> +{
> +	u32 next_off;
> +	int i;
> +
> +	if (IS_ERR_OR_NULL(rec))
> +		return;
> +
> +	if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +		return;
> +
> +	for (i = 0; i < rec->cnt; i++) {
> +		if (rec->fields[i].type != BPF_KPTR_USER)
> +			continue;
> +
> +		next_off = rec->fields[i].offset;
> +		bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +		*(void **)(src + next_off) = NULL;
> +	}
> +}
> +
> +static int bpf_map_update_value_inner(struct bpf_map *map, struct file *map_file,
> +				      void *key, void *value, __u64 flags)
>   {
>   	int err;
>   
> @@ -208,6 +334,29 @@ static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>   	return err;
>   }
>   
> +static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> +				void *key, void *value, __u64 flags)
> +{
> +	int err;
> +
> +	if (flags & BPF_FROM_USER) {
> +		/* Pin user memory can lead to context switch, so we need
> +		 * to do it before potential RCU lock.
> +		 */
> +		err = bpf_obj_trans_pin_uaddrs(map->record, value,
> +					       bpf_map_value_size(map));
> +		if (err)
> +			return err;
> +	}
> +
> +	err = bpf_map_update_value_inner(map, map_file, key, value, flags);
> +
> +	if (err && (flags & BPF_FROM_USER))
> +		bpf_obj_unpin_uaddrs(map->record, value);
> +
> +	return err;
> +}
> +
>   static int bpf_map_copy_value(struct bpf_map *map, void *key, void *value,
>   			      __u64 flags)
>   {
> @@ -714,6 +863,11 @@ void bpf_obj_free_fields(const struct btf_record *rec, void *obj)
>   				field->kptr.dtor(xchgd_field);
>   			}
>   			break;
> +		case BPF_KPTR_USER:
> +			if (virt_addr_valid(*(void **)field_ptr))
> +				bpf_obj_unpin_uaddr(field, *(void **)field_ptr);
> +			*(void **)field_ptr = NULL;
> +			break;
>   		case BPF_LIST_HEAD:
>   			if (WARN_ON_ONCE(rec->spin_lock_off < 0))
>   				continue;
> @@ -1155,6 +1309,12 @@ static int map_check_btf(struct bpf_map *map, struct bpf_token *token,
>   					goto free_map_tab;
>   				}
>   				break;
> +			case BPF_KPTR_USER:
> +				if (map->map_type != BPF_MAP_TYPE_TASK_STORAGE) {
> +					ret = -EOPNOTSUPP;
> +					goto free_map_tab;
> +				}
> +				break;
>   			case BPF_LIST_HEAD:
>   			case BPF_RB_ROOT:
>   				if (map->map_type != BPF_MAP_TYPE_HASH &&
> @@ -1618,11 +1778,15 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>   	struct bpf_map *map;
>   	void *key, *value;
>   	u32 value_size;
> +	u64 extra_flags = 0;
>   	struct fd f;
>   	int err;
>   
>   	if (CHECK_ATTR(BPF_MAP_UPDATE_ELEM))
>   		return -EINVAL;
> +	/* Prevent userspace from setting any internal flags */
> +	if (attr->flags & ~(BIT(BPF_MAP_UPDATE_FLAG_BITS) - 1))
> +		return -EINVAL;
>   
>   	f = fdget(ufd);
>   	map = __bpf_map_get(f);
> @@ -1653,7 +1817,9 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>   		goto free_key;
>   	}
>   
> -	err = bpf_map_update_value(map, f.file, key, value, attr->flags);
> +	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +		extra_flags |= BPF_FROM_USER;
> +	err = bpf_map_update_value(map, f.file, key, value, attr->flags | extra_flags);
>   	if (!err)
>   		maybe_wait_bpf_programs(map);
>   
> @@ -1852,6 +2018,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   	void __user *keys = u64_to_user_ptr(attr->batch.keys);
>   	u32 value_size, cp, max_count;
>   	void *key, *value;
> +	u64 extra_flags = 0;
>   	int err = 0;
>   
>   	if (attr->batch.elem_flags & ~BPF_F_LOCK)
> @@ -1881,6 +2048,8 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   		return -ENOMEM;
>   	}
>   
> +	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +		extra_flags |= BPF_FROM_USER;
>   	for (cp = 0; cp < max_count; cp++) {
>   		err = -EFAULT;
>   		if (copy_from_user(key, keys + cp * map->key_size,
> @@ -1889,7 +2058,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>   			break;
>   
>   		err = bpf_map_update_value(map, map_file, key, value,
> -					   attr->batch.elem_flags);
> +					   attr->batch.elem_flags | extra_flags);
>   
>   		if (err)
>   			break;
> diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c
> index bc01b3aa6b0f..db5281384e6a 100644
> --- a/net/core/bpf_sk_storage.c
> +++ b/net/core/bpf_sk_storage.c
> @@ -137,7 +137,7 @@ bpf_sk_storage_clone_elem(struct sock *newsk,
>   {
>   	struct bpf_local_storage_elem *copy_selem;
>   
> -	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
> +	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC, false);
>   	if (!copy_selem)
>   		return NULL;
>
Alexei Starovoitov Aug. 12, 2024, 4:45 p.m. UTC | #4
On Wed, Aug 7, 2024 at 4:58 PM Kui-Feng Lee <thinker.li@gmail.com> wrote:
>
> User kptrs are pinned, by pin_user_pages_fast(), and translated to an
> address in the kernel when the value is updated by user programs. (Call
> bpf_map_update_elem() from user programs.) And, the pinned pages are
> unpinned if the values of user kptrs are overwritten or if the values of maps
> are deleted/destroyed.
>
> The pages are mapped through vmap() in order to get a continuous space in
> the kernel if the memory pointed by a user kptr resides in two or more
> pages. For the case of single page, page_address() is called to get the
> address of a page in the kernel.
>
> User kptr is only supported by task storage maps.
>
> One user kptr can pin at most KPTR_USER_MAX_PAGES(16) physical pages. This
> is a random picked number for safety. We actually can remove this
> restriction totally.
>
> User kptrs could only be set by user programs through syscalls.  Any
> attempts of updating the value of a map with __kptr_user in it should
> ignore the values of user kptrs from BPF programs. The values of user kptrs
> will keep as they were if the new values are from BPF programs, not from
> user programs.
>
> Cc: linux-mm@kvack.org
> Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
> ---
>  include/linux/bpf.h               |  35 +++++-
>  include/linux/bpf_local_storage.h |   2 +-
>  kernel/bpf/bpf_local_storage.c    |  18 +--
>  kernel/bpf/helpers.c              |  12 +-
>  kernel/bpf/local_storage.c        |   2 +-
>  kernel/bpf/syscall.c              | 177 +++++++++++++++++++++++++++++-
>  net/core/bpf_sk_storage.c         |   2 +-
>  7 files changed, 227 insertions(+), 21 deletions(-)
>
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 87d5f98249e2..f4ad0bc183cb 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -30,6 +30,7 @@
>  #include <linux/static_call.h>
>  #include <linux/memcontrol.h>
>  #include <linux/cfi.h>
> +#include <linux/mm.h>
>
>  struct bpf_verifier_env;
>  struct bpf_verifier_log;
> @@ -477,10 +478,12 @@ static inline void bpf_long_memcpy(void *dst, const void *src, u32 size)
>                 data_race(*ldst++ = *lsrc++);
>  }
>
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr);
> +
>  /* copy everything but bpf_spin_lock, bpf_timer, and kptrs. There could be one of each. */
>  static inline void bpf_obj_memcpy(struct btf_record *rec,
>                                   void *dst, void *src, u32 size,
> -                                 bool long_memcpy)
> +                                 bool long_memcpy, bool from_user)
>  {
>         u32 curr_off = 0;
>         int i;
> @@ -496,21 +499,40 @@ static inline void bpf_obj_memcpy(struct btf_record *rec,
>         for (i = 0; i < rec->cnt; i++) {
>                 u32 next_off = rec->fields[i].offset;
>                 u32 sz = next_off - curr_off;
> +               void *addr;
>
>                 memcpy(dst + curr_off, src + curr_off, sz);
> +               if (from_user && rec->fields[i].type == BPF_KPTR_USER) {


Do not add this to bpf_obj_memcpy() which is a critical path
for various map operations.
This has to be standalone for task storage only.
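
i.e. keep bpf_obj_memcpy()/copy_map_value() untouched and do the uptr
fixup in a helper that only the task-storage syscall path calls, e.g.
(hypothetical sketch, based on the hunk above):

	static void bpf_task_storage_copy_uptrs(const struct btf_record *rec,
						void *dst, void *src)
	{
		void *old;
		u32 off;
		int i;

		for (i = 0; i < rec->cnt; i++) {
			if (rec->fields[i].type != BPF_KPTR_USER)
				continue;

			off = rec->fields[i].offset;
			/* bpf_obj_memcpy() skips the field itself, so dst
			 * still holds the previously pinned address.
			 */
			old = *(void **)(dst + off);
			if (old)
				bpf_obj_unpin_uaddr(&rec->fields[i], old);
			*(void **)(dst + off) = *(void **)(src + off);
		}
	}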

> +                       /* Unpin old address.
> +                        *
> +                        * Alignments are guaranteed by btf_find_field_one().
> +                        */
> +                       addr = *(void **)(dst + next_off);
> +                       if (virt_addr_valid(addr))
> +                               bpf_obj_unpin_uaddr(&rec->fields[i], addr);
> +                       else if (addr)
> +                               WARN_ON_ONCE(1);
> +
> +                       *(void **)(dst + next_off) = *(void **)(src + next_off);
> +               }
>                 curr_off += rec->fields[i].size + sz;
>         }
>         memcpy(dst + curr_off, src + curr_off, size - curr_off);
>  }
>
> +static inline void copy_map_value_user(struct bpf_map *map, void *dst, void *src, bool from_user)

No need for these helpers either.

> +{
> +       bpf_obj_memcpy(map->record, dst, src, map->value_size, false, from_user);
> +}
> +
>  static inline void copy_map_value(struct bpf_map *map, void *dst, void *src)
>  {
> -       bpf_obj_memcpy(map->record, dst, src, map->value_size, false);
> +       bpf_obj_memcpy(map->record, dst, src, map->value_size, false, false);
>  }
>
>  static inline void copy_map_value_long(struct bpf_map *map, void *dst, void *src)
>  {
> -       bpf_obj_memcpy(map->record, dst, src, map->value_size, true);
> +       bpf_obj_memcpy(map->record, dst, src, map->value_size, true, false);
>  }
>
>  static inline void bpf_obj_memzero(struct btf_record *rec, void *dst, u32 size)
> @@ -538,6 +560,8 @@ static inline void zero_map_value(struct bpf_map *map, void *dst)
>         bpf_obj_memzero(map->record, dst, map->value_size);
>  }
>
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +                               bool lock_src, bool from_user);
>  void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>                            bool lock_src);
>  void bpf_timer_cancel_and_free(void *timer);
> @@ -775,6 +799,11 @@ enum bpf_arg_type {
>  };
>  static_assert(__BPF_ARG_TYPE_MAX <= BPF_BASE_TYPE_LIMIT);
>
> +#define BPF_MAP_UPDATE_FLAG_BITS 3
> +enum bpf_map_update_flag {
> +       BPF_FROM_USER = BIT(0 + BPF_MAP_UPDATE_FLAG_BITS)
> +};
> +
>  /* type of values returned from helper functions */
>  enum bpf_return_type {
>         RET_INTEGER,                    /* function returns integer */
> diff --git a/include/linux/bpf_local_storage.h b/include/linux/bpf_local_storage.h
> index dcddb0aef7d8..d337df68fa23 100644
> --- a/include/linux/bpf_local_storage.h
> +++ b/include/linux/bpf_local_storage.h
> @@ -181,7 +181,7 @@ void bpf_selem_link_map(struct bpf_local_storage_map *smap,
>
>  struct bpf_local_storage_elem *
>  bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner, void *value,
> -               bool charge_mem, gfp_t gfp_flags);
> +               bool charge_mem, gfp_t gfp_flags, bool from_user);
>
>  void bpf_selem_free(struct bpf_local_storage_elem *selem,
>                     struct bpf_local_storage_map *smap,
> diff --git a/kernel/bpf/bpf_local_storage.c b/kernel/bpf/bpf_local_storage.c
> index c938dea5ddbf..c4cf09e27a19 100644
> --- a/kernel/bpf/bpf_local_storage.c
> +++ b/kernel/bpf/bpf_local_storage.c
> @@ -73,7 +73,7 @@ static bool selem_linked_to_map(const struct bpf_local_storage_elem *selem)
>
>  struct bpf_local_storage_elem *
>  bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
> -               void *value, bool charge_mem, gfp_t gfp_flags)
> +               void *value, bool charge_mem, gfp_t gfp_flags, bool from_user)
>  {
>         struct bpf_local_storage_elem *selem;
>
> @@ -100,7 +100,7 @@ bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
>
>         if (selem) {
>                 if (value)
> -                       copy_map_value(&smap->map, SDATA(selem)->data, value);
> +                       copy_map_value_user(&smap->map, SDATA(selem)->data, value, from_user);
>                 /* No need to call check_and_init_map_value as memory is zero init */
>                 return selem;
>         }
> @@ -530,9 +530,11 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>         struct bpf_local_storage_elem *alloc_selem, *selem = NULL;
>         struct bpf_local_storage *local_storage;
>         unsigned long flags;
> +       bool from_user = map_flags & BPF_FROM_USER;
>         int err;
>
>         /* BPF_EXIST and BPF_NOEXIST cannot be both set */
> +       map_flags &= ~BPF_FROM_USER;
>         if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
>             /* BPF_F_LOCK can only be used in a value with spin_lock */
>             unlikely((map_flags & BPF_F_LOCK) &&
> @@ -550,7 +552,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>                 if (err)
>                         return ERR_PTR(err);
>
> -               selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +               selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>                 if (!selem)
>                         return ERR_PTR(-ENOMEM);
>
> @@ -575,8 +577,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>                 if (err)
>                         return ERR_PTR(err);
>                 if (old_sdata && selem_linked_to_storage_lockless(SELEM(old_sdata))) {
> -                       copy_map_value_locked(&smap->map, old_sdata->data,
> -                                             value, false);
> +                       copy_map_value_locked_user(&smap->map, old_sdata->data,
> +                                                  value, false, from_user);
>                         return old_sdata;
>                 }
>         }
> @@ -584,7 +586,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>         /* A lookup has just been done before and concluded a new selem is
>          * needed. The chance of an unnecessary alloc is unlikely.
>          */
> -       alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
> +       alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>         if (!alloc_selem)
>                 return ERR_PTR(-ENOMEM);
>
> @@ -607,8 +609,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>                 goto unlock;
>
>         if (old_sdata && (map_flags & BPF_F_LOCK)) {
> -               copy_map_value_locked(&smap->map, old_sdata->data, value,
> -                                     false);
> +               copy_map_value_locked_user(&smap->map, old_sdata->data, value,
> +                                          false, from_user);
>                 selem = SELEM(old_sdata);
>                 goto unlock;
>         }
> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
> index d02ae323996b..4aef86209fdd 100644
> --- a/kernel/bpf/helpers.c
> +++ b/kernel/bpf/helpers.c
> @@ -372,8 +372,8 @@ const struct bpf_func_proto bpf_spin_unlock_proto = {
>         .arg1_btf_id    = BPF_PTR_POISON,
>  };
>
> -void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> -                          bool lock_src)
> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
> +                               bool lock_src, bool from_user)
>  {
>         struct bpf_spin_lock *lock;
>
> @@ -383,11 +383,17 @@ void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>                 lock = dst + map->record->spin_lock_off;
>         preempt_disable();
>         __bpf_spin_lock_irqsave(lock);
> -       copy_map_value(map, dst, src);
> +       copy_map_value_user(map, dst, src, from_user);
>         __bpf_spin_unlock_irqrestore(lock);
>         preempt_enable();
>  }
>
> +void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
> +                          bool lock_src)
> +{
> +       copy_map_value_locked_user(map, dst, src, lock_src, false);
> +}
> +
>  BPF_CALL_0(bpf_jiffies64)
>  {
>         return get_jiffies_64();
> diff --git a/kernel/bpf/local_storage.c b/kernel/bpf/local_storage.c
> index 3969eb0382af..62a12fa8ce9e 100644
> --- a/kernel/bpf/local_storage.c
> +++ b/kernel/bpf/local_storage.c
> @@ -147,7 +147,7 @@ static long cgroup_storage_update_elem(struct bpf_map *map, void *key,
>         struct bpf_cgroup_storage *storage;
>         struct bpf_storage_buffer *new;
>
> -       if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST)))
> +       if (unlikely(flags & ~BPF_F_LOCK))
>                 return -EINVAL;
>
>         if (unlikely((flags & BPF_F_LOCK) &&
> diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
> index 90a25307480e..eaa2a9d13265 100644
> --- a/kernel/bpf/syscall.c
> +++ b/kernel/bpf/syscall.c
> @@ -155,8 +155,134 @@ static void maybe_wait_bpf_programs(struct bpf_map *map)
>                 synchronize_rcu();
>  }
>
> -static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> -                               void *key, void *value, __u64 flags)
> +static void *trans_addr_pages(struct page **pages, int npages)
> +{
> +       if (npages == 1)
> +               return page_address(pages[0]);
> +       /* For multiple pages, we need to use vmap() to get a contiguous
> +        * virtual address range.
> +        */
> +       return vmap(pages, npages, VM_MAP, PAGE_KERNEL);
> +}

Don't quite see a need for trans_addr_pages() helper when it's used
once.
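
i.e. just open-code it at the single call site (sketch):

	if (npages == 1)
		kern_addr = page_address(pages[0]);
	else
		kern_addr = vmap(pages, npages, VM_MAP, PAGE_KERNEL);
	if (!kern_addr)
		return -ENOMEM;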

> +
> +#define KPTR_USER_MAX_PAGES 16
> +
> +static int bpf_obj_trans_pin_uaddr(struct btf_field *field, void **addr)
> +{
> +       const struct btf_type *t;
> +       struct page *pages[KPTR_USER_MAX_PAGES];
> +       void *ptr, *kern_addr;
> +       u32 type_id, tsz;
> +       int r, npages;
> +
> +       ptr = *addr;
> +       type_id = field->kptr.btf_id;
> +       t = btf_type_id_size(field->kptr.btf, &type_id, &tsz);
> +       if (!t)
> +               return -EINVAL;
> +       if (tsz == 0) {
> +               *addr = NULL;
> +               return 0;
> +       }
> +
> +       npages = (((intptr_t)ptr + tsz + ~PAGE_MASK) -
> +                 ((intptr_t)ptr & PAGE_MASK)) >> PAGE_SHIFT;
> +       if (npages > KPTR_USER_MAX_PAGES)
> +               return -E2BIG;
> +       r = pin_user_pages_fast((intptr_t)ptr & PAGE_MASK, npages, 0, pages);

No need to apply the mask on ptr. See pin_user_pages_fast() internals.

It probably should be FOLL_WRITE | FOLL_LONGTERM instead of 0.
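
i.e. something like (pin_user_pages_fast() page-aligns the start address
internally):

	r = pin_user_pages_fast((unsigned long)ptr, npages,
				FOLL_WRITE | FOLL_LONGTERM, pages);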

> +       if (r != npages)
> +               return -EINVAL;
> +       kern_addr = trans_addr_pages(pages, npages);
> +       if (!kern_addr)
> +               return -ENOMEM;
> +       *addr = kern_addr + ((intptr_t)ptr & ~PAGE_MASK);
> +       return 0;
> +}
> +
> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr)
> +{
> +       struct page *pages[KPTR_USER_MAX_PAGES];
> +       int npages, i;
> +       u32 size, type_id;
> +       void *ptr;
> +
> +       type_id = field->kptr.btf_id;
> +       btf_type_id_size(field->kptr.btf, &type_id, &size);
> +       if (size == 0)
> +               return;
> +
> +       ptr = (void *)((intptr_t)addr & PAGE_MASK);
> +       npages = (((intptr_t)addr + size + ~PAGE_MASK) - (intptr_t)ptr) >> PAGE_SHIFT;
> +       for (i = 0; i < npages; i++) {
> +               pages[i] = virt_to_page(ptr);
> +               ptr += PAGE_SIZE;
> +       }
> +       if (npages > 1)
> +               /* Paired with vmap() in trans_addr_pages() */
> +               vunmap((void *)((intptr_t)addr & PAGE_MASK));
> +       unpin_user_pages(pages, npages);
> +}
> +
> +static int bpf_obj_trans_pin_uaddrs(struct btf_record *rec, void *src, u32 size)
> +{
> +       u32 next_off;
> +       int i, err;
> +
> +       if (IS_ERR_OR_NULL(rec))
> +               return 0;
> +
> +       if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +               return 0;

imo kptr_user doesn't quite fit as a name.
'kptr' means 'kernel pointer'. Here it's user addr.
Maybe just "uptr" ?
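
e.g. the map value would then be declared with something like (the
annotation spelling and struct names here are hypothetical):

	struct value_type {
		struct user_data __uptr *udata;
	};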

> +
> +       for (i = 0; i < rec->cnt; i++) {
> +               if (rec->fields[i].type != BPF_KPTR_USER)
> +                       continue;
> +
> +               next_off = rec->fields[i].offset;
> +               if (next_off + sizeof(void *) > size)
> +                       return -EINVAL;
> +               err = bpf_obj_trans_pin_uaddr(&rec->fields[i], src + next_off);
> +               if (!err)
> +                       continue;
> +
> +               /* Rollback */
> +               for (i--; i >= 0; i--) {
> +                       if (rec->fields[i].type != BPF_KPTR_USER)
> +                               continue;
> +                       next_off = rec->fields[i].offset;
> +                       bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +                       *(void **)(src + next_off) = NULL;
> +               }
> +
> +               return err;
> +       }
> +
> +       return 0;
> +}
> +
> +static void bpf_obj_unpin_uaddrs(struct btf_record *rec, void *src)
> +{
> +       u32 next_off;
> +       int i;
> +
> +       if (IS_ERR_OR_NULL(rec))
> +               return;
> +
> +       if (!btf_record_has_field(rec, BPF_KPTR_USER))
> +               return;
> +
> +       for (i = 0; i < rec->cnt; i++) {
> +               if (rec->fields[i].type != BPF_KPTR_USER)
> +                       continue;
> +
> +               next_off = rec->fields[i].offset;
> +               bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
> +               *(void **)(src + next_off) = NULL;

This part is pretty much the same as the undo part in
bpf_obj_trans_pin_uaddrs() and the common helper is warranted.
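
e.g. a small shared helper (hypothetical name) that both the rollback
loop and bpf_obj_unpin_uaddrs() can call:

	static void bpf_obj_unpin_uaddr_field(const struct btf_field *field,
					      void *src)
	{
		u32 off = field->offset;

		bpf_obj_unpin_uaddr(field, *(void **)(src + off));
		*(void **)(src + off) = NULL;
	}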

> +       }
> +}
> +
> +static int bpf_map_update_value_inner(struct bpf_map *map, struct file *map_file,
> +                                     void *key, void *value, __u64 flags)
>  {
>         int err;
>
> @@ -208,6 +334,29 @@ static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>         return err;
>  }
>
> +static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> +                               void *key, void *value, __u64 flags)
> +{
> +       int err;
> +
> +       if (flags & BPF_FROM_USER) {

there shouldn't be a need for this extra flag.
map->record has the info whether uptr is present or not.
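
i.e. something along these lines (sketch):

	if (btf_record_has_field(map->record, BPF_KPTR_USER)) {
		/* Pinning user memory may sleep, so do it before any
		 * potential RCU lock.
		 */
		err = bpf_obj_trans_pin_uaddrs(map->record, value,
					       bpf_map_value_size(map));
		if (err)
			return err;
	}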

> +               /* Pin user memory can lead to context switch, so we need
> +                * to do it before potential RCU lock.
> +                */
> +               err = bpf_obj_trans_pin_uaddrs(map->record, value,
> +                                              bpf_map_value_size(map));
> +               if (err)
> +                       return err;
> +       }
> +
> +       err = bpf_map_update_value_inner(map, map_file, key, value, flags);
> +
> +       if (err && (flags & BPF_FROM_USER))
> +               bpf_obj_unpin_uaddrs(map->record, value);
> +
> +       return err;
> +}
> +
>  static int bpf_map_copy_value(struct bpf_map *map, void *key, void *value,
>                               __u64 flags)
>  {
> @@ -714,6 +863,11 @@ void bpf_obj_free_fields(const struct btf_record *rec, void *obj)
>                                 field->kptr.dtor(xchgd_field);
>                         }
>                         break;
> +               case BPF_KPTR_USER:
> +                       if (virt_addr_valid(*(void **)field_ptr))
> +                               bpf_obj_unpin_uaddr(field, *(void **)field_ptr);
> +                       *(void **)field_ptr = NULL;
> +                       break;
>                 case BPF_LIST_HEAD:
>                         if (WARN_ON_ONCE(rec->spin_lock_off < 0))
>                                 continue;
> @@ -1155,6 +1309,12 @@ static int map_check_btf(struct bpf_map *map, struct bpf_token *token,
>                                         goto free_map_tab;
>                                 }
>                                 break;
> +                       case BPF_KPTR_USER:
> +                               if (map->map_type != BPF_MAP_TYPE_TASK_STORAGE) {
> +                                       ret = -EOPNOTSUPP;
> +                                       goto free_map_tab;
> +                               }
> +                               break;
>                         case BPF_LIST_HEAD:
>                         case BPF_RB_ROOT:
>                                 if (map->map_type != BPF_MAP_TYPE_HASH &&
> @@ -1618,11 +1778,15 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>         struct bpf_map *map;
>         void *key, *value;
>         u32 value_size;
> +       u64 extra_flags = 0;
>         struct fd f;
>         int err;
>
>         if (CHECK_ATTR(BPF_MAP_UPDATE_ELEM))
>                 return -EINVAL;
> +       /* Prevent userspace from setting any internal flags */
> +       if (attr->flags & ~(BIT(BPF_MAP_UPDATE_FLAG_BITS) - 1))
> +               return -EINVAL;
>
>         f = fdget(ufd);
>         map = __bpf_map_get(f);
> @@ -1653,7 +1817,9 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>                 goto free_key;
>         }
>
> -       err = bpf_map_update_value(map, f.file, key, value, attr->flags);
> +       if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +               extra_flags |= BPF_FROM_USER;
> +       err = bpf_map_update_value(map, f.file, key, value, attr->flags | extra_flags);
>         if (!err)
>                 maybe_wait_bpf_programs(map);
>
> @@ -1852,6 +2018,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>         void __user *keys = u64_to_user_ptr(attr->batch.keys);
>         u32 value_size, cp, max_count;
>         void *key, *value;
> +       u64 extra_flags = 0;
>         int err = 0;
>
>         if (attr->batch.elem_flags & ~BPF_F_LOCK)
> @@ -1881,6 +2048,8 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>                 return -ENOMEM;
>         }
>
> +       if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
> +               extra_flags |= BPF_FROM_USER;
>         for (cp = 0; cp < max_count; cp++) {
>                 err = -EFAULT;
>                 if (copy_from_user(key, keys + cp * map->key_size,
> @@ -1889,7 +2058,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>                         break;
>
>                 err = bpf_map_update_value(map, map_file, key, value,
> -                                          attr->batch.elem_flags);
> +                                          attr->batch.elem_flags | extra_flags);
>
>                 if (err)
>                         break;
> diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c
> index bc01b3aa6b0f..db5281384e6a 100644
> --- a/net/core/bpf_sk_storage.c
> +++ b/net/core/bpf_sk_storage.c
> @@ -137,7 +137,7 @@ bpf_sk_storage_clone_elem(struct sock *newsk,
>  {
>         struct bpf_local_storage_elem *copy_selem;
>
> -       copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
> +       copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC, false);
>         if (!copy_selem)
>                 return NULL;
>
> --
> 2.34.1
>
Kui-Feng Lee Aug. 12, 2024, 5:24 p.m. UTC | #5
On 8/12/24 09:45, Alexei Starovoitov wrote:
> On Wed, Aug 7, 2024 at 4:58 PM Kui-Feng Lee <thinker.li@gmail.com> wrote:
>>
>> User kptrs are pinned, by pin_user_pages_fast(), and translated to an
>> address in the kernel when the value is updated by user programs. (Call
>> bpf_map_update_elem() from user programs.) And, the pinned pages are
>> unpinned if the values of user kptrs are overwritten or if the values of maps
>> are deleted/destroyed.
>>
>> The pages are mapped through vmap() in order to get a continuous space in
>> the kernel if the memory pointed by a user kptr resides in two or more
>> pages. For the case of single page, page_address() is called to get the
>> address of a page in the kernel.
>>
>> User kptr is only supported by task storage maps.
>>
>> One user kptr can pin at most KPTR_USER_MAX_PAGES(16) physical pages. This
>> is a random picked number for safety. We actually can remove this
>> restriction totally.
>>
>> User kptrs could only be set by user programs through syscalls.  Any
>> attempts of updating the value of a map with __kptr_user in it should
>> ignore the values of user kptrs from BPF programs. The values of user kptrs
>> will keep as they were if the new values are from BPF programs, not from
>> user programs.
>>
>> Cc: linux-mm@kvack.org
>> Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
>> ---
>>   include/linux/bpf.h               |  35 +++++-
>>   include/linux/bpf_local_storage.h |   2 +-
>>   kernel/bpf/bpf_local_storage.c    |  18 +--
>>   kernel/bpf/helpers.c              |  12 +-
>>   kernel/bpf/local_storage.c        |   2 +-
>>   kernel/bpf/syscall.c              | 177 +++++++++++++++++++++++++++++-
>>   net/core/bpf_sk_storage.c         |   2 +-
>>   7 files changed, 227 insertions(+), 21 deletions(-)
>>
>> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
>> index 87d5f98249e2..f4ad0bc183cb 100644
>> --- a/include/linux/bpf.h
>> +++ b/include/linux/bpf.h
>> @@ -30,6 +30,7 @@
>>   #include <linux/static_call.h>
>>   #include <linux/memcontrol.h>
>>   #include <linux/cfi.h>
>> +#include <linux/mm.h>
>>
>>   struct bpf_verifier_env;
>>   struct bpf_verifier_log;
>> @@ -477,10 +478,12 @@ static inline void bpf_long_memcpy(void *dst, const void *src, u32 size)
>>                  data_race(*ldst++ = *lsrc++);
>>   }
>>
>> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr);
>> +
>>   /* copy everything but bpf_spin_lock, bpf_timer, and kptrs. There could be one of each. */
>>   static inline void bpf_obj_memcpy(struct btf_record *rec,
>>                                    void *dst, void *src, u32 size,
>> -                                 bool long_memcpy)
>> +                                 bool long_memcpy, bool from_user)
>>   {
>>          u32 curr_off = 0;
>>          int i;
>> @@ -496,21 +499,40 @@ static inline void bpf_obj_memcpy(struct btf_record *rec,
>>          for (i = 0; i < rec->cnt; i++) {
>>                  u32 next_off = rec->fields[i].offset;
>>                  u32 sz = next_off - curr_off;
>> +               void *addr;
>>
>>                  memcpy(dst + curr_off, src + curr_off, sz);
>> +               if (from_user && rec->fields[i].type == BPF_KPTR_USER) {
> 
> 
> Do not add this to bpf_obj_memcpy() which is a critical path
> for various map operations.
> This has to be standalone for task storage only.
> 
>> +                       /* Unpin old address.
>> +                        *
>> +                        * Alignments are guaranteed by btf_find_field_one().
>> +                        */
>> +                       addr = *(void **)(dst + next_off);
>> +                       if (virt_addr_valid(addr))
>> +                               bpf_obj_unpin_uaddr(&rec->fields[i], addr);
>> +                       else if (addr)
>> +                               WARN_ON_ONCE(1);
>> +
>> +                       *(void **)(dst + next_off) = *(void **)(src + next_off);
>> +               }
>>                  curr_off += rec->fields[i].size + sz;
>>          }
>>          memcpy(dst + curr_off, src + curr_off, size - curr_off);
>>   }
>>
>> +static inline void copy_map_value_user(struct bpf_map *map, void *dst, void *src, bool from_user)
> 
> No need for these helpers either.
> 
>> +{
>> +       bpf_obj_memcpy(map->record, dst, src, map->value_size, false, from_user);
>> +}
>> +
>>   static inline void copy_map_value(struct bpf_map *map, void *dst, void *src)
>>   {
>> -       bpf_obj_memcpy(map->record, dst, src, map->value_size, false);
>> +       bpf_obj_memcpy(map->record, dst, src, map->value_size, false, false);
>>   }
>>
>>   static inline void copy_map_value_long(struct bpf_map *map, void *dst, void *src)
>>   {
>> -       bpf_obj_memcpy(map->record, dst, src, map->value_size, true);
>> +       bpf_obj_memcpy(map->record, dst, src, map->value_size, true, false);
>>   }
>>
>>   static inline void bpf_obj_memzero(struct btf_record *rec, void *dst, u32 size)
>> @@ -538,6 +560,8 @@ static inline void zero_map_value(struct bpf_map *map, void *dst)
>>          bpf_obj_memzero(map->record, dst, map->value_size);
>>   }
>>
>> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
>> +                               bool lock_src, bool from_user);
>>   void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>>                             bool lock_src);
>>   void bpf_timer_cancel_and_free(void *timer);
>> @@ -775,6 +799,11 @@ enum bpf_arg_type {
>>   };
>>   static_assert(__BPF_ARG_TYPE_MAX <= BPF_BASE_TYPE_LIMIT);
>>
>> +#define BPF_MAP_UPDATE_FLAG_BITS 3
>> +enum bpf_map_update_flag {
>> +       BPF_FROM_USER = BIT(0 + BPF_MAP_UPDATE_FLAG_BITS)
>> +};
>> +
>>   /* type of values returned from helper functions */
>>   enum bpf_return_type {
>>          RET_INTEGER,                    /* function returns integer */
>> diff --git a/include/linux/bpf_local_storage.h b/include/linux/bpf_local_storage.h
>> index dcddb0aef7d8..d337df68fa23 100644
>> --- a/include/linux/bpf_local_storage.h
>> +++ b/include/linux/bpf_local_storage.h
>> @@ -181,7 +181,7 @@ void bpf_selem_link_map(struct bpf_local_storage_map *smap,
>>
>>   struct bpf_local_storage_elem *
>>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner, void *value,
>> -               bool charge_mem, gfp_t gfp_flags);
>> +               bool charge_mem, gfp_t gfp_flags, bool from_user);
>>
>>   void bpf_selem_free(struct bpf_local_storage_elem *selem,
>>                      struct bpf_local_storage_map *smap,
>> diff --git a/kernel/bpf/bpf_local_storage.c b/kernel/bpf/bpf_local_storage.c
>> index c938dea5ddbf..c4cf09e27a19 100644
>> --- a/kernel/bpf/bpf_local_storage.c
>> +++ b/kernel/bpf/bpf_local_storage.c
>> @@ -73,7 +73,7 @@ static bool selem_linked_to_map(const struct bpf_local_storage_elem *selem)
>>
>>   struct bpf_local_storage_elem *
>>   bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
>> -               void *value, bool charge_mem, gfp_t gfp_flags)
>> +               void *value, bool charge_mem, gfp_t gfp_flags, bool from_user)
>>   {
>>          struct bpf_local_storage_elem *selem;
>>
>> @@ -100,7 +100,7 @@ bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
>>
>>          if (selem) {
>>                  if (value)
>> -                       copy_map_value(&smap->map, SDATA(selem)->data, value);
>> +                       copy_map_value_user(&smap->map, SDATA(selem)->data, value, from_user);
>>                  /* No need to call check_and_init_map_value as memory is zero init */
>>                  return selem;
>>          }
>> @@ -530,9 +530,11 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>>          struct bpf_local_storage_elem *alloc_selem, *selem = NULL;
>>          struct bpf_local_storage *local_storage;
>>          unsigned long flags;
>> +       bool from_user = map_flags & BPF_FROM_USER;
>>          int err;
>>
>>          /* BPF_EXIST and BPF_NOEXIST cannot be both set */
>> +       map_flags &= ~BPF_FROM_USER;
>>          if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
>>              /* BPF_F_LOCK can only be used in a value with spin_lock */
>>              unlikely((map_flags & BPF_F_LOCK) &&
>> @@ -550,7 +552,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>>                  if (err)
>>                          return ERR_PTR(err);
>>
>> -               selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
>> +               selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>>                  if (!selem)
>>                          return ERR_PTR(-ENOMEM);
>>
>> @@ -575,8 +577,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>>                  if (err)
>>                          return ERR_PTR(err);
>>                  if (old_sdata && selem_linked_to_storage_lockless(SELEM(old_sdata))) {
>> -                       copy_map_value_locked(&smap->map, old_sdata->data,
>> -                                             value, false);
>> +                       copy_map_value_locked_user(&smap->map, old_sdata->data,
>> +                                                  value, false, from_user);
>>                          return old_sdata;
>>                  }
>>          }
>> @@ -584,7 +586,7 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>>          /* A lookup has just been done before and concluded a new selem is
>>           * needed. The chance of an unnecessary alloc is unlikely.
>>           */
>> -       alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
>> +       alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
>>          if (!alloc_selem)
>>                  return ERR_PTR(-ENOMEM);
>>
>> @@ -607,8 +609,8 @@ bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
>>                  goto unlock;
>>
>>          if (old_sdata && (map_flags & BPF_F_LOCK)) {
>> -               copy_map_value_locked(&smap->map, old_sdata->data, value,
>> -                                     false);
>> +               copy_map_value_locked_user(&smap->map, old_sdata->data, value,
>> +                                          false, from_user);
>>                  selem = SELEM(old_sdata);
>>                  goto unlock;
>>          }
>> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
>> index d02ae323996b..4aef86209fdd 100644
>> --- a/kernel/bpf/helpers.c
>> +++ b/kernel/bpf/helpers.c
>> @@ -372,8 +372,8 @@ const struct bpf_func_proto bpf_spin_unlock_proto = {
>>          .arg1_btf_id    = BPF_PTR_POISON,
>>   };
>>
>> -void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>> -                          bool lock_src)
>> +void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
>> +                               bool lock_src, bool from_user)
>>   {
>>          struct bpf_spin_lock *lock;
>>
>> @@ -383,11 +383,17 @@ void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>>                  lock = dst + map->record->spin_lock_off;
>>          preempt_disable();
>>          __bpf_spin_lock_irqsave(lock);
>> -       copy_map_value(map, dst, src);
>> +       copy_map_value_user(map, dst, src, from_user);
>>          __bpf_spin_unlock_irqrestore(lock);
>>          preempt_enable();
>>   }
>>
>> +void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
>> +                          bool lock_src)
>> +{
>> +       copy_map_value_locked_user(map, dst, src, lock_src, false);
>> +}
>> +
>>   BPF_CALL_0(bpf_jiffies64)
>>   {
>>          return get_jiffies_64();
>> diff --git a/kernel/bpf/local_storage.c b/kernel/bpf/local_storage.c
>> index 3969eb0382af..62a12fa8ce9e 100644
>> --- a/kernel/bpf/local_storage.c
>> +++ b/kernel/bpf/local_storage.c
>> @@ -147,7 +147,7 @@ static long cgroup_storage_update_elem(struct bpf_map *map, void *key,
>>          struct bpf_cgroup_storage *storage;
>>          struct bpf_storage_buffer *new;
>>
>> -       if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST)))
>> +       if (unlikely(flags & ~BPF_F_LOCK))
>>                  return -EINVAL;
>>
>>          if (unlikely((flags & BPF_F_LOCK) &&
>> diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
>> index 90a25307480e..eaa2a9d13265 100644
>> --- a/kernel/bpf/syscall.c
>> +++ b/kernel/bpf/syscall.c
>> @@ -155,8 +155,134 @@ static void maybe_wait_bpf_programs(struct bpf_map *map)
>>                  synchronize_rcu();
>>   }
>>
>> -static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>> -                               void *key, void *value, __u64 flags)
>> +static void *trans_addr_pages(struct page **pages, int npages)
>> +{
>> +       if (npages == 1)
>> +               return page_address(pages[0]);
>> +       /* For multiple pages, we need to use vmap() to get a contiguous
>> +        * virtual address range.
>> +        */
>> +       return vmap(pages, npages, VM_MAP, PAGE_KERNEL);
>> +}
> 
> Don't quite see a need for trans_addr_pages() helper when it's used
> once.
> 
>> +
>> +#define KPTR_USER_MAX_PAGES 16
>> +
>> +static int bpf_obj_trans_pin_uaddr(struct btf_field *field, void **addr)
>> +{
>> +       const struct btf_type *t;
>> +       struct page *pages[KPTR_USER_MAX_PAGES];
>> +       void *ptr, *kern_addr;
>> +       u32 type_id, tsz;
>> +       int r, npages;
>> +
>> +       ptr = *addr;
>> +       type_id = field->kptr.btf_id;
>> +       t = btf_type_id_size(field->kptr.btf, &type_id, &tsz);
>> +       if (!t)
>> +               return -EINVAL;
>> +       if (tsz == 0) {
>> +               *addr = NULL;
>> +               return 0;
>> +       }
>> +
>> +       npages = (((intptr_t)ptr + tsz + ~PAGE_MASK) -
>> +                 ((intptr_t)ptr & PAGE_MASK)) >> PAGE_SHIFT;
>> +       if (npages > KPTR_USER_MAX_PAGES)
>> +               return -E2BIG;
>> +       r = pin_user_pages_fast((intptr_t)ptr & PAGE_MASK, npages, 0, pages);
> 
> No need to apply the mask on ptr. See pin_user_pages_fast() internals.
> 
> It probably should be FOLL_WRITE | FOLL_LONGTERM instead of 0.

Agree!

> 
>> +       if (r != npages)
>> +               return -EINVAL;
>> +       kern_addr = trans_addr_pages(pages, npages);
>> +       if (!kern_addr)
>> +               return -ENOMEM;
>> +       *addr = kern_addr + ((intptr_t)ptr & ~PAGE_MASK);
>> +       return 0;
>> +}
>> +
>> +void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr)
>> +{
>> +       struct page *pages[KPTR_USER_MAX_PAGES];
>> +       int npages, i;
>> +       u32 size, type_id;
>> +       void *ptr;
>> +
>> +       type_id = field->kptr.btf_id;
>> +       btf_type_id_size(field->kptr.btf, &type_id, &size);
>> +       if (size == 0)
>> +               return;
>> +
>> +       ptr = (void *)((intptr_t)addr & PAGE_MASK);
>> +       npages = (((intptr_t)addr + size + ~PAGE_MASK) - (intptr_t)ptr) >> PAGE_SHIFT;
>> +       for (i = 0; i < npages; i++) {
>> +               pages[i] = virt_to_page(ptr);
>> +               ptr += PAGE_SIZE;
>> +       }
>> +       if (npages > 1)
>> +               /* Paired with vmap() in trans_addr_pages() */
>> +               vunmap((void *)((intptr_t)addr & PAGE_MASK));
>> +       unpin_user_pages(pages, npages);
>> +}
>> +
>> +static int bpf_obj_trans_pin_uaddrs(struct btf_record *rec, void *src, u32 size)
>> +{
>> +       u32 next_off;
>> +       int i, err;
>> +
>> +       if (IS_ERR_OR_NULL(rec))
>> +               return 0;
>> +
>> +       if (!btf_record_has_field(rec, BPF_KPTR_USER))
>> +               return 0;
> 
> imo kptr_user doesn't quite fit as a name.
> 'kptr' means 'kernel pointer'. Here it's user addr.
> Maybe just "uptr" ?

That makes sense.

> 
>> +
>> +       for (i = 0; i < rec->cnt; i++) {
>> +               if (rec->fields[i].type != BPF_KPTR_USER)
>> +                       continue;
>> +
>> +               next_off = rec->fields[i].offset;
>> +               if (next_off + sizeof(void *) > size)
>> +                       return -EINVAL;
>> +               err = bpf_obj_trans_pin_uaddr(&rec->fields[i], src + next_off);
>> +               if (!err)
>> +                       continue;
>> +
>> +               /* Rollback */
>> +               for (i--; i >= 0; i--) {
>> +                       if (rec->fields[i].type != BPF_KPTR_USER)
>> +                               continue;
>> +                       next_off = rec->fields[i].offset;
>> +                       bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
>> +                       *(void **)(src + next_off) = NULL;
>> +               }
>> +
>> +               return err;
>> +       }
>> +
>> +       return 0;
>> +}
>> +
>> +static void bpf_obj_unpin_uaddrs(struct btf_record *rec, void *src)
>> +{
>> +       u32 next_off;
>> +       int i;
>> +
>> +       if (IS_ERR_OR_NULL(rec))
>> +               return;
>> +
>> +       if (!btf_record_has_field(rec, BPF_KPTR_USER))
>> +               return;
>> +
>> +       for (i = 0; i < rec->cnt; i++) {
>> +               if (rec->fields[i].type != BPF_KPTR_USER)
>> +                       continue;
>> +
>> +               next_off = rec->fields[i].offset;
>> +               bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
>> +               *(void **)(src + next_off) = NULL;
> 
> This part is pretty much the same as the undo part in
> bpf_obj_trans_pin_uaddrs() and the common helper is warranted.

Sure!

> 
>> +       }
>> +}
>> +
>> +static int bpf_map_update_value_inner(struct bpf_map *map, struct file *map_file,
>> +                                     void *key, void *value, __u64 flags)
>>   {
>>          int err;
>>
>> @@ -208,6 +334,29 @@ static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>>          return err;
>>   }
>>
>> +static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>> +                               void *key, void *value, __u64 flags)
>> +{
>> +       int err;
>> +
>> +       if (flags & BPF_FROM_USER) {
> 
> there shouldn't be a need for this extra flag.
> map->record has the info whether uptr is present or not.

The BPF_FROM_USER flag is used to support updating map values from BPF
programs as well. Although BPF programs can update map values, I
don't want the values of uptrs to be changed by the BPF programs.

Should we just forbid the BPF programs to update the map values having
uptrs in them?


> 
>> +               /* Pin user memory can lead to context switch, so we need
>> +                * to do it before potential RCU lock.
>> +                */
>> +               err = bpf_obj_trans_pin_uaddrs(map->record, value,
>> +                                              bpf_map_value_size(map));
>> +               if (err)
>> +                       return err;
>> +       }
>> +
>> +       err = bpf_map_update_value_inner(map, map_file, key, value, flags);
>> +
>> +       if (err && (flags & BPF_FROM_USER))
>> +               bpf_obj_unpin_uaddrs(map->record, value);
>> +
>> +       return err;
>> +}
>> +
>>   static int bpf_map_copy_value(struct bpf_map *map, void *key, void *value,
>>                                __u64 flags)
>>   {
>> @@ -714,6 +863,11 @@ void bpf_obj_free_fields(const struct btf_record *rec, void *obj)
>>                                  field->kptr.dtor(xchgd_field);
>>                          }
>>                          break;
>> +               case BPF_KPTR_USER:
>> +                       if (virt_addr_valid(*(void **)field_ptr))
>> +                               bpf_obj_unpin_uaddr(field, *(void **)field_ptr);
>> +                       *(void **)field_ptr = NULL;
>> +                       break;
>>                  case BPF_LIST_HEAD:
>>                          if (WARN_ON_ONCE(rec->spin_lock_off < 0))
>>                                  continue;
>> @@ -1155,6 +1309,12 @@ static int map_check_btf(struct bpf_map *map, struct bpf_token *token,
>>                                          goto free_map_tab;
>>                                  }
>>                                  break;
>> +                       case BPF_KPTR_USER:
>> +                               if (map->map_type != BPF_MAP_TYPE_TASK_STORAGE) {
>> +                                       ret = -EOPNOTSUPP;
>> +                                       goto free_map_tab;
>> +                               }
>> +                               break;
>>                          case BPF_LIST_HEAD:
>>                          case BPF_RB_ROOT:
>>                                  if (map->map_type != BPF_MAP_TYPE_HASH &&
>> @@ -1618,11 +1778,15 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>>          struct bpf_map *map;
>>          void *key, *value;
>>          u32 value_size;
>> +       u64 extra_flags = 0;
>>          struct fd f;
>>          int err;
>>
>>          if (CHECK_ATTR(BPF_MAP_UPDATE_ELEM))
>>                  return -EINVAL;
>> +       /* Prevent userspace from setting any internal flags */
>> +       if (attr->flags & ~(BIT(BPF_MAP_UPDATE_FLAG_BITS) - 1))
>> +               return -EINVAL;
>>
>>          f = fdget(ufd);
>>          map = __bpf_map_get(f);
>> @@ -1653,7 +1817,9 @@ static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
>>                  goto free_key;
>>          }
>>
>> -       err = bpf_map_update_value(map, f.file, key, value, attr->flags);
>> +       if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
>> +               extra_flags |= BPF_FROM_USER;
>> +       err = bpf_map_update_value(map, f.file, key, value, attr->flags | extra_flags);
>>          if (!err)
>>                  maybe_wait_bpf_programs(map);
>>
>> @@ -1852,6 +2018,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>>          void __user *keys = u64_to_user_ptr(attr->batch.keys);
>>          u32 value_size, cp, max_count;
>>          void *key, *value;
>> +       u64 extra_flags = 0;
>>          int err = 0;
>>
>>          if (attr->batch.elem_flags & ~BPF_F_LOCK)
>> @@ -1881,6 +2048,8 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>>                  return -ENOMEM;
>>          }
>>
>> +       if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
>> +               extra_flags |= BPF_FROM_USER;
>>          for (cp = 0; cp < max_count; cp++) {
>>                  err = -EFAULT;
>>                  if (copy_from_user(key, keys + cp * map->key_size,
>> @@ -1889,7 +2058,7 @@ int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
>>                          break;
>>
>>                  err = bpf_map_update_value(map, map_file, key, value,
>> -                                          attr->batch.elem_flags);
>> +                                          attr->batch.elem_flags | extra_flags);
>>
>>                  if (err)
>>                          break;
>> diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c
>> index bc01b3aa6b0f..db5281384e6a 100644
>> --- a/net/core/bpf_sk_storage.c
>> +++ b/net/core/bpf_sk_storage.c
>> @@ -137,7 +137,7 @@ bpf_sk_storage_clone_elem(struct sock *newsk,
>>   {
>>          struct bpf_local_storage_elem *copy_selem;
>>
>> -       copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
>> +       copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC, false);
>>          if (!copy_selem)
>>                  return NULL;
>>
>> --
>> 2.34.1
>>
Alexei Starovoitov Aug. 12, 2024, 5:36 p.m. UTC | #6
On Mon, Aug 12, 2024 at 10:24 AM Kui-Feng Lee <sinquersw@gmail.com> wrote:
>
>
> >> +static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
> >> +                               void *key, void *value, __u64 flags)
> >> +{
> >> +       int err;
> >> +
> >> +       if (flags & BPF_FROM_USER) {
> >
> > there shouldn't be a need for this extra flag.
> > map->record has the info whether uptr is present or not.
>
> The BPF_FROM_USER flag is used to support updating map values from BPF
> programs as well. Although BPF programs can update map values, I
> don't want the values of uptrs to be changed by the BPF programs.
>
> Should we just forbid the BPF programs to update the map values having
> uptrs in them?

hmm. map_update_elem() is disallowed from bpf prog.

        case BPF_MAP_TYPE_TASK_STORAGE:
                if (func_id != BPF_FUNC_task_storage_get &&
                    func_id != BPF_FUNC_task_storage_delete &&
                    func_id != BPF_FUNC_kptr_xchg)
                        goto error;
Kui-Feng Lee Aug. 12, 2024, 5:51 p.m. UTC | #7
On 8/12/24 10:36, Alexei Starovoitov wrote:
> On Mon, Aug 12, 2024 at 10:24 AM Kui-Feng Lee <sinquersw@gmail.com> wrote:
>>
>>
>>>> +static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
>>>> +                               void *key, void *value, __u64 flags)
>>>> +{
>>>> +       int err;
>>>> +
>>>> +       if (flags & BPF_FROM_USER) {
>>>
>>> there shouldn't be a need for this extra flag.
>>> map->record has the info whether uptr is present or not.
>>
>> The BPF_FROM_USER flag is used to support updating map values from BPF
>> programs as well. Although BPF programs can update map values, I
>> don't want the values of uptrs to be changed by the BPF programs.
>>
>> Should we just forbid the BPF programs to update the map values having
>> uptrs in them?
> 
> hmm. map_update_elem() is disallowed from bpf prog.
> 
>          case BPF_MAP_TYPE_TASK_STORAGE:
>                  if (func_id != BPF_FUNC_task_storage_get &&
>                      func_id != BPF_FUNC_task_storage_delete &&
>                      func_id != BPF_FUNC_kptr_xchg)
>                          goto error;

Thank you for the information!
diff mbox series

Patch

diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index 87d5f98249e2..f4ad0bc183cb 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -30,6 +30,7 @@ 
 #include <linux/static_call.h>
 #include <linux/memcontrol.h>
 #include <linux/cfi.h>
+#include <linux/mm.h>
 
 struct bpf_verifier_env;
 struct bpf_verifier_log;
@@ -477,10 +478,12 @@  static inline void bpf_long_memcpy(void *dst, const void *src, u32 size)
 		data_race(*ldst++ = *lsrc++);
 }
 
+void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr);
+
 /* copy everything but bpf_spin_lock, bpf_timer, and kptrs. There could be one of each. */
 static inline void bpf_obj_memcpy(struct btf_record *rec,
 				  void *dst, void *src, u32 size,
-				  bool long_memcpy)
+				  bool long_memcpy, bool from_user)
 {
 	u32 curr_off = 0;
 	int i;
@@ -496,21 +499,40 @@  static inline void bpf_obj_memcpy(struct btf_record *rec,
 	for (i = 0; i < rec->cnt; i++) {
 		u32 next_off = rec->fields[i].offset;
 		u32 sz = next_off - curr_off;
+		void *addr;
 
 		memcpy(dst + curr_off, src + curr_off, sz);
+		if (from_user && rec->fields[i].type == BPF_KPTR_USER) {
+			/* Unpin old address.
+			 *
+			 * Alignments are guaranteed by btf_find_field_one().
+			 */
+			addr = *(void **)(dst + next_off);
+			if (virt_addr_valid(addr))
+				bpf_obj_unpin_uaddr(&rec->fields[i], addr);
+			else if (addr)
+				WARN_ON_ONCE(1);
+
+			*(void **)(dst + next_off) = *(void **)(src + next_off);
+		}
 		curr_off += rec->fields[i].size + sz;
 	}
 	memcpy(dst + curr_off, src + curr_off, size - curr_off);
 }
 
+static inline void copy_map_value_user(struct bpf_map *map, void *dst, void *src, bool from_user)
+{
+	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, from_user);
+}
+
 static inline void copy_map_value(struct bpf_map *map, void *dst, void *src)
 {
-	bpf_obj_memcpy(map->record, dst, src, map->value_size, false);
+	bpf_obj_memcpy(map->record, dst, src, map->value_size, false, false);
 }
 
 static inline void copy_map_value_long(struct bpf_map *map, void *dst, void *src)
 {
-	bpf_obj_memcpy(map->record, dst, src, map->value_size, true);
+	bpf_obj_memcpy(map->record, dst, src, map->value_size, true, false);
 }
 
 static inline void bpf_obj_memzero(struct btf_record *rec, void *dst, u32 size)
@@ -538,6 +560,8 @@  static inline void zero_map_value(struct bpf_map *map, void *dst)
 	bpf_obj_memzero(map->record, dst, map->value_size);
 }
 
+void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
+				bool lock_src, bool from_user);
 void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
 			   bool lock_src);
 void bpf_timer_cancel_and_free(void *timer);
@@ -775,6 +799,11 @@  enum bpf_arg_type {
 };
 static_assert(__BPF_ARG_TYPE_MAX <= BPF_BASE_TYPE_LIMIT);
 
+#define BPF_MAP_UPDATE_FLAG_BITS 3
+enum bpf_map_update_flag {
+	BPF_FROM_USER = BIT(0 + BPF_MAP_UPDATE_FLAG_BITS)
+};
+
 /* type of values returned from helper functions */
 enum bpf_return_type {
 	RET_INTEGER,			/* function returns integer */
diff --git a/include/linux/bpf_local_storage.h b/include/linux/bpf_local_storage.h
index dcddb0aef7d8..d337df68fa23 100644
--- a/include/linux/bpf_local_storage.h
+++ b/include/linux/bpf_local_storage.h
@@ -181,7 +181,7 @@  void bpf_selem_link_map(struct bpf_local_storage_map *smap,
 
 struct bpf_local_storage_elem *
 bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner, void *value,
-		bool charge_mem, gfp_t gfp_flags);
+		bool charge_mem, gfp_t gfp_flags, bool from_user);
 
 void bpf_selem_free(struct bpf_local_storage_elem *selem,
 		    struct bpf_local_storage_map *smap,
diff --git a/kernel/bpf/bpf_local_storage.c b/kernel/bpf/bpf_local_storage.c
index c938dea5ddbf..c4cf09e27a19 100644
--- a/kernel/bpf/bpf_local_storage.c
+++ b/kernel/bpf/bpf_local_storage.c
@@ -73,7 +73,7 @@  static bool selem_linked_to_map(const struct bpf_local_storage_elem *selem)
 
 struct bpf_local_storage_elem *
 bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
-		void *value, bool charge_mem, gfp_t gfp_flags)
+		void *value, bool charge_mem, gfp_t gfp_flags, bool from_user)
 {
 	struct bpf_local_storage_elem *selem;
 
@@ -100,7 +100,7 @@  bpf_selem_alloc(struct bpf_local_storage_map *smap, void *owner,
 
 	if (selem) {
 		if (value)
-			copy_map_value(&smap->map, SDATA(selem)->data, value);
+			copy_map_value_user(&smap->map, SDATA(selem)->data, value, from_user);
 		/* No need to call check_and_init_map_value as memory is zero init */
 		return selem;
 	}
@@ -530,9 +530,11 @@  bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
 	struct bpf_local_storage_elem *alloc_selem, *selem = NULL;
 	struct bpf_local_storage *local_storage;
 	unsigned long flags;
+	bool from_user = map_flags & BPF_FROM_USER;
 	int err;
 
 	/* BPF_EXIST and BPF_NOEXIST cannot be both set */
+	map_flags &= ~BPF_FROM_USER;
 	if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
 	    /* BPF_F_LOCK can only be used in a value with spin_lock */
 	    unlikely((map_flags & BPF_F_LOCK) &&
@@ -550,7 +552,7 @@  bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
 		if (err)
 			return ERR_PTR(err);
 
-		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
+		selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
 		if (!selem)
 			return ERR_PTR(-ENOMEM);
 
@@ -575,8 +577,8 @@  bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
 		if (err)
 			return ERR_PTR(err);
 		if (old_sdata && selem_linked_to_storage_lockless(SELEM(old_sdata))) {
-			copy_map_value_locked(&smap->map, old_sdata->data,
-					      value, false);
+			copy_map_value_locked_user(&smap->map, old_sdata->data,
+						   value, false, from_user);
 			return old_sdata;
 		}
 	}
@@ -584,7 +586,7 @@  bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
 	/* A lookup has just been done before and concluded a new selem is
 	 * needed. The chance of an unnecessary alloc is unlikely.
 	 */
-	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags);
+	alloc_selem = selem = bpf_selem_alloc(smap, owner, value, true, gfp_flags, from_user);
 	if (!alloc_selem)
 		return ERR_PTR(-ENOMEM);
 
@@ -607,8 +609,8 @@  bpf_local_storage_update(void *owner, struct bpf_local_storage_map *smap,
 		goto unlock;
 
 	if (old_sdata && (map_flags & BPF_F_LOCK)) {
-		copy_map_value_locked(&smap->map, old_sdata->data, value,
-				      false);
+		copy_map_value_locked_user(&smap->map, old_sdata->data, value,
+					   false, from_user);
 		selem = SELEM(old_sdata);
 		goto unlock;
 	}
diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
index d02ae323996b..4aef86209fdd 100644
--- a/kernel/bpf/helpers.c
+++ b/kernel/bpf/helpers.c
@@ -372,8 +372,8 @@  const struct bpf_func_proto bpf_spin_unlock_proto = {
 	.arg1_btf_id    = BPF_PTR_POISON,
 };
 
-void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
-			   bool lock_src)
+void copy_map_value_locked_user(struct bpf_map *map, void *dst, void *src,
+				bool lock_src, bool from_user)
 {
 	struct bpf_spin_lock *lock;
 
@@ -383,11 +383,17 @@  void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
 		lock = dst + map->record->spin_lock_off;
 	preempt_disable();
 	__bpf_spin_lock_irqsave(lock);
-	copy_map_value(map, dst, src);
+	copy_map_value_user(map, dst, src, from_user);
 	__bpf_spin_unlock_irqrestore(lock);
 	preempt_enable();
 }
 
+void copy_map_value_locked(struct bpf_map *map, void *dst, void *src,
+			   bool lock_src)
+{
+	copy_map_value_locked_user(map, dst, src, lock_src, false);
+}
+
 BPF_CALL_0(bpf_jiffies64)
 {
 	return get_jiffies_64();
diff --git a/kernel/bpf/local_storage.c b/kernel/bpf/local_storage.c
index 3969eb0382af..62a12fa8ce9e 100644
--- a/kernel/bpf/local_storage.c
+++ b/kernel/bpf/local_storage.c
@@ -147,7 +147,7 @@  static long cgroup_storage_update_elem(struct bpf_map *map, void *key,
 	struct bpf_cgroup_storage *storage;
 	struct bpf_storage_buffer *new;
 
-	if (unlikely(flags & ~(BPF_F_LOCK | BPF_EXIST)))
+	if (unlikely(flags & ~BPF_F_LOCK))
 		return -EINVAL;
 
 	if (unlikely((flags & BPF_F_LOCK) &&
diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
index 90a25307480e..eaa2a9d13265 100644
--- a/kernel/bpf/syscall.c
+++ b/kernel/bpf/syscall.c
@@ -155,8 +155,134 @@  static void maybe_wait_bpf_programs(struct bpf_map *map)
 		synchronize_rcu();
 }
 
-static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
-				void *key, void *value, __u64 flags)
+static void *trans_addr_pages(struct page **pages, int npages)
+{
+	if (npages == 1)
+		return page_address(pages[0]);
+	/* For multiple pages, we need to use vmap() to get a contiguous
+	 * virtual address range.
+	 */
+	return vmap(pages, npages, VM_MAP, PAGE_KERNEL);
+}
+
+#define KPTR_USER_MAX_PAGES 16
+
+static int bpf_obj_trans_pin_uaddr(struct btf_field *field, void **addr)
+{
+	const struct btf_type *t;
+	struct page *pages[KPTR_USER_MAX_PAGES];
+	void *ptr, *kern_addr;
+	u32 type_id, tsz;
+	int r, npages;
+
+	ptr = *addr;
+	type_id = field->kptr.btf_id;
+	t = btf_type_id_size(field->kptr.btf, &type_id, &tsz);
+	if (!t)
+		return -EINVAL;
+	if (tsz == 0) {
+		*addr = NULL;
+		return 0;
+	}
+
+	npages = (((intptr_t)ptr + tsz + ~PAGE_MASK) -
+		  ((intptr_t)ptr & PAGE_MASK)) >> PAGE_SHIFT;
+	if (npages > KPTR_USER_MAX_PAGES)
+		return -E2BIG;
+	r = pin_user_pages_fast((intptr_t)ptr & PAGE_MASK, npages, 0, pages);
+	if (r != npages)
+		return -EINVAL;
+	kern_addr = trans_addr_pages(pages, npages);
+	if (!kern_addr)
+		return -ENOMEM;
+	*addr = kern_addr + ((intptr_t)ptr & ~PAGE_MASK);
+	return 0;
+}
+
+void bpf_obj_unpin_uaddr(const struct btf_field *field, void *addr)
+{
+	struct page *pages[KPTR_USER_MAX_PAGES];
+	int npages, i;
+	u32 size, type_id;
+	void *ptr;
+
+	type_id = field->kptr.btf_id;
+	btf_type_id_size(field->kptr.btf, &type_id, &size);
+	if (size == 0)
+		return;
+
+	ptr = (void *)((intptr_t)addr & PAGE_MASK);
+	npages = (((intptr_t)addr + size + ~PAGE_MASK) - (intptr_t)ptr) >> PAGE_SHIFT;
+	for (i = 0; i < npages; i++) {
+		pages[i] = virt_to_page(ptr);
+		ptr += PAGE_SIZE;
+	}
+	if (npages > 1)
+		/* Paired with vmap() in trans_addr_pages() */
+		vunmap((void *)((intptr_t)addr & PAGE_MASK));
+	unpin_user_pages(pages, npages);
+}
+
+static int bpf_obj_trans_pin_uaddrs(struct btf_record *rec, void *src, u32 size)
+{
+	u32 next_off;
+	int i, err;
+
+	if (IS_ERR_OR_NULL(rec))
+		return 0;
+
+	if (!btf_record_has_field(rec, BPF_KPTR_USER))
+		return 0;
+
+	for (i = 0; i < rec->cnt; i++) {
+		if (rec->fields[i].type != BPF_KPTR_USER)
+			continue;
+
+		next_off = rec->fields[i].offset;
+		if (next_off + sizeof(void *) > size)
+			return -EINVAL;
+		err = bpf_obj_trans_pin_uaddr(&rec->fields[i], src + next_off);
+		if (!err)
+			continue;
+
+		/* Rollback */
+		for (i--; i >= 0; i--) {
+			if (rec->fields[i].type != BPF_KPTR_USER)
+				continue;
+			next_off = rec->fields[i].offset;
+			bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
+			*(void **)(src + next_off) = NULL;
+		}
+
+		return err;
+	}
+
+	return 0;
+}
+
+static void bpf_obj_unpin_uaddrs(struct btf_record *rec, void *src)
+{
+	u32 next_off;
+	int i;
+
+	if (IS_ERR_OR_NULL(rec))
+		return;
+
+	if (!btf_record_has_field(rec, BPF_KPTR_USER))
+		return;
+
+	for (i = 0; i < rec->cnt; i++) {
+		if (rec->fields[i].type != BPF_KPTR_USER)
+			continue;
+
+		next_off = rec->fields[i].offset;
+		bpf_obj_unpin_uaddr(&rec->fields[i], *(void **)(src + next_off));
+		*(void **)(src + next_off) = NULL;
+	}
+}
+
+static int bpf_map_update_value_inner(struct bpf_map *map, struct file *map_file,
+				      void *key, void *value, __u64 flags)
 {
 	int err;
 
@@ -208,6 +334,29 @@  static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
 	return err;
 }
 
+static int bpf_map_update_value(struct bpf_map *map, struct file *map_file,
+				void *key, void *value, __u64 flags)
+{
+	int err;
+
+	if (flags & BPF_FROM_USER) {
+		/* Pinning user memory can lead to a context switch, so we
+		 * need to do it before taking any potential RCU lock.
+		 */
+		err = bpf_obj_trans_pin_uaddrs(map->record, value,
+					       bpf_map_value_size(map));
+		if (err)
+			return err;
+	}
+
+	err = bpf_map_update_value_inner(map, map_file, key, value, flags);
+
+	if (err && (flags & BPF_FROM_USER))
+		bpf_obj_unpin_uaddrs(map->record, value);
+
+	return err;
+}
+
 static int bpf_map_copy_value(struct bpf_map *map, void *key, void *value,
 			      __u64 flags)
 {
@@ -714,6 +863,11 @@  void bpf_obj_free_fields(const struct btf_record *rec, void *obj)
 				field->kptr.dtor(xchgd_field);
 			}
 			break;
+		case BPF_KPTR_USER:
+			if (virt_addr_valid(*(void **)field_ptr))
+				bpf_obj_unpin_uaddr(field, *(void **)field_ptr);
+			*(void **)field_ptr = NULL;
+			break;
 		case BPF_LIST_HEAD:
 			if (WARN_ON_ONCE(rec->spin_lock_off < 0))
 				continue;
@@ -1155,6 +1309,12 @@  static int map_check_btf(struct bpf_map *map, struct bpf_token *token,
 					goto free_map_tab;
 				}
 				break;
+			case BPF_KPTR_USER:
+				if (map->map_type != BPF_MAP_TYPE_TASK_STORAGE) {
+					ret = -EOPNOTSUPP;
+					goto free_map_tab;
+				}
+				break;
 			case BPF_LIST_HEAD:
 			case BPF_RB_ROOT:
 				if (map->map_type != BPF_MAP_TYPE_HASH &&
@@ -1618,11 +1778,15 @@  static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
 	struct bpf_map *map;
 	void *key, *value;
 	u32 value_size;
+	u64 extra_flags = 0;
 	struct fd f;
 	int err;
 
 	if (CHECK_ATTR(BPF_MAP_UPDATE_ELEM))
 		return -EINVAL;
+	/* Prevent userspace from setting any internal flags */
+	if (attr->flags & ~(BIT(BPF_MAP_UPDATE_FLAG_BITS) - 1))
+		return -EINVAL;
 
 	f = fdget(ufd);
 	map = __bpf_map_get(f);
@@ -1653,7 +1817,9 @@  static int map_update_elem(union bpf_attr *attr, bpfptr_t uattr)
 		goto free_key;
 	}
 
-	err = bpf_map_update_value(map, f.file, key, value, attr->flags);
+	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
+		extra_flags |= BPF_FROM_USER;
+	err = bpf_map_update_value(map, f.file, key, value, attr->flags | extra_flags);
 	if (!err)
 		maybe_wait_bpf_programs(map);
 
@@ -1852,6 +2018,7 @@  int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
 	void __user *keys = u64_to_user_ptr(attr->batch.keys);
 	u32 value_size, cp, max_count;
 	void *key, *value;
+	u64 extra_flags = 0;
 	int err = 0;
 
 	if (attr->batch.elem_flags & ~BPF_F_LOCK)
@@ -1881,6 +2048,8 @@  int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
 		return -ENOMEM;
 	}
 
+	if (map->map_type == BPF_MAP_TYPE_TASK_STORAGE)
+		extra_flags |= BPF_FROM_USER;
 	for (cp = 0; cp < max_count; cp++) {
 		err = -EFAULT;
 		if (copy_from_user(key, keys + cp * map->key_size,
@@ -1889,7 +2058,7 @@  int generic_map_update_batch(struct bpf_map *map, struct file *map_file,
 			break;
 
 		err = bpf_map_update_value(map, map_file, key, value,
-					   attr->batch.elem_flags);
+					   attr->batch.elem_flags | extra_flags);
 
 		if (err)
 			break;
diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c
index bc01b3aa6b0f..db5281384e6a 100644
--- a/net/core/bpf_sk_storage.c
+++ b/net/core/bpf_sk_storage.c
@@ -137,7 +137,7 @@  bpf_sk_storage_clone_elem(struct sock *newsk,
 {
 	struct bpf_local_storage_elem *copy_selem;
 
-	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
+	copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC, false);
 	if (!copy_selem)
 		return NULL;