@@ -3003,6 +3003,45 @@ static inline int vm_fault_to_errno(vm_fault_t vm_fault, int foll_flags)
return 0;
}
+/*
+ * Indicates for which pages that are write-protected in the page table,
+ * whether GUP has to trigger unsharing via FAULT_FLAG_UNSHARE such that the
+ * GUP pin will remain consistent with the pages mapped into the page tables
+ * of the MM.
+ *
+ * Temporary unmapping of PageAnonExclusive() pages or clearing of
+ * PageAnonExclusive() has to protect against concurrent GUP:
+ * * Ordinary GUP: Using the PT lock
+ * * GUP-fast and fork(): mm->write_protect_seq
+ * * GUP-fast and KSM or temporary unmapping (swap, migration):
+ * clear/invalidate+flush of the page table entry
+ *
+ * Must be called with the (sub)page that's actually referenced via the
+ * page table entry, which might not necessarily be the head page for a
+ * PTE-mapped THP.
+ */
+static inline bool gup_must_unshare(unsigned int flags, struct page *page)
+{
+ /*
+ * FOLL_WRITE is implicitly handled correctly as the page table entry
+ * has to be writable -- and if it references (part of) an anonymous
+ * folio, that part is required to be marked exclusive.
+ */
+ if ((flags & (FOLL_WRITE | FOLL_PIN)) != FOLL_PIN)
+ return false;
+ /*
+ * Note: PageAnon(page) is stable until the page is actually getting
+ * freed.
+ */
+ if (!PageAnon(page))
+ return false;
+ /*
+ * Note that PageKsm() pages cannot be exclusive, and consequently,
+ * cannot get pinned.
+ */
+ return !PageAnonExclusive(page);
+}
+
typedef int (*pte_fn_t)(pte_t *pte, unsigned long addr, void *data);
extern int apply_to_page_range(struct mm_struct *mm, unsigned long address,
unsigned long size, pte_fn_t fn, void *data);
@@ -568,6 +568,10 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
}
}
+ if (!pte_write(pte) && gup_must_unshare(flags, page)) {
+ page = ERR_PTR(-EMLINK);
+ goto out;
+ }
/* try_grab_page() does nothing unless FOLL_GET or FOLL_PIN is set. */
if (unlikely(!try_grab_page(page, flags))) {
page = ERR_PTR(-ENOMEM);
@@ -820,6 +824,11 @@ static struct page *follow_p4d_mask(struct vm_area_struct *vma,
* When getting pages from ZONE_DEVICE memory, the @ctx->pgmap caches
* the device's dev_pagemap metadata to avoid repeating expensive lookups.
*
+ * When getting an anonymous page and the caller has to trigger unsharing
+ * of a shared anonymous page first, -EMLINK is returned. The caller should
+ * trigger a fault with FAULT_FLAG_UNSHARE set. Note that unsharing is only
+ * relevant with FOLL_PIN and !FOLL_WRITE.
+ *
* On output, the @ctx->page_mask is set according to the size of the page.
*
* Return: the mapped (struct page *), %NULL if no mapping exists, or
@@ -943,7 +952,8 @@ static int get_gate_page(struct mm_struct *mm, unsigned long address,
* is, *@locked will be set to 0 and -EBUSY returned.
*/
static int faultin_page(struct vm_area_struct *vma,
- unsigned long address, unsigned int *flags, int *locked)
+ unsigned long address, unsigned int *flags, bool unshare,
+ int *locked)
{
unsigned int fault_flags = 0;
vm_fault_t ret;
@@ -968,6 +978,11 @@ static int faultin_page(struct vm_area_struct *vma,
*/
fault_flags |= FAULT_FLAG_TRIED;
}
+ if (unshare) {
+ fault_flags |= FAULT_FLAG_UNSHARE;
+ /* FAULT_FLAG_WRITE and FAULT_FLAG_UNSHARE are incompatible */
+ VM_BUG_ON(fault_flags & FAULT_FLAG_WRITE);
+ }
ret = handle_mm_fault(vma, address, fault_flags, NULL);
if (ret & VM_FAULT_ERROR) {
@@ -1189,8 +1204,9 @@ static long __get_user_pages(struct mm_struct *mm,
cond_resched();
page = follow_page_mask(vma, start, foll_flags, &ctx);
- if (!page) {
- ret = faultin_page(vma, start, &foll_flags, locked);
+ if (!page || PTR_ERR(page) == -EMLINK) {
+ ret = faultin_page(vma, start, &foll_flags,
+ PTR_ERR(page) == -EMLINK, locked);
switch (ret) {
case 0:
goto retry;
@@ -2346,6 +2362,11 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
goto pte_unmap;
}
+ if (!pte_write(pte) && gup_must_unshare(flags, page)) {
+ put_compound_head(head, 1, flags);
+ goto pte_unmap;
+ }
+
VM_BUG_ON_PAGE(compound_head(page) != head, page);
/*
@@ -2529,6 +2550,11 @@ static int gup_hugepte(pte_t *ptep, unsigned long sz, unsigned long addr,
return 0;
}
+ if (!pte_write(pte) && gup_must_unshare(flags, head)) {
+ put_compound_head(head, refs, flags);
+ return 0;
+ }
+
*nr += refs;
SetPageReferenced(head);
return 1;
@@ -2589,6 +2615,11 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
return 0;
}
+ if (!pmd_write(orig) && gup_must_unshare(flags, head)) {
+ put_compound_head(head, refs, flags);
+ return 0;
+ }
+
*nr += refs;
SetPageReferenced(head);
return 1;
@@ -2623,6 +2654,11 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
return 0;
}
+ if (!pud_write(orig) && gup_must_unshare(flags, head)) {
+ put_compound_head(head, refs, flags);
+ return 0;
+ }
+
*nr += refs;
SetPageReferenced(head);
return 1;
@@ -1389,6 +1389,9 @@ struct page *follow_trans_huge_pmd(struct vm_area_struct *vma,
page = pmd_page(*pmd);
VM_BUG_ON_PAGE(!PageHead(page) && !is_zone_device_page(page), page);
+ if (!pmd_write(*pmd) && gup_must_unshare(flags, page))
+ return ERR_PTR(-EMLINK);
+
if (!try_grab_page(page, flags))
return ERR_PTR(-ENOMEM);
@@ -5956,6 +5956,25 @@ static void record_subpages_vmas(struct page *page, struct vm_area_struct *vma,
}
}
+static inline bool __follow_hugetlb_must_fault(unsigned int flags, pte_t *pte,
+ bool *unshare)
+{
+ pte_t pteval = huge_ptep_get(pte);
+
+ *unshare = false;
+ if (is_swap_pte(pteval))
+ return true;
+ if (huge_pte_write(pteval))
+ return false;
+ if (flags & FOLL_WRITE)
+ return true;
+ if (gup_must_unshare(flags, pte_page(pteval))) {
+ *unshare = true;
+ return true;
+ }
+ return false;
+}
+
long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
struct page **pages, struct vm_area_struct **vmas,
unsigned long *position, unsigned long *nr_pages,
@@ -5970,6 +5989,7 @@ long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
while (vaddr < vma->vm_end && remainder) {
pte_t *pte;
spinlock_t *ptl = NULL;
+ bool unshare = false;
int absent;
struct page *page;
@@ -6020,9 +6040,8 @@ long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
* both cases, and because we can't follow correct pages
* directly from any kind of swap entries.
*/
- if (absent || is_swap_pte(huge_ptep_get(pte)) ||
- ((flags & FOLL_WRITE) &&
- !huge_pte_write(huge_ptep_get(pte)))) {
+ if (absent ||
+ __follow_hugetlb_must_fault(flags, pte, &unshare)) {
vm_fault_t ret;
unsigned int fault_flags = 0;
@@ -6030,6 +6049,8 @@ long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
spin_unlock(ptl);
if (flags & FOLL_WRITE)
fault_flags |= FAULT_FLAG_WRITE;
+ else if (unshare)
+ fault_flags |= FAULT_FLAG_UNSHARE;
if (locked)
fault_flags |= FAULT_FLAG_ALLOW_RETRY |
FAULT_FLAG_KILLABLE;