@@ -9,9 +9,11 @@
#include <linux/bits.h>
/* Uprobes on this MM assume 32-bit code */
-#define MM_CONTEXT_UPROBE_IA32 BIT(0)
+#define MM_CONTEXT_UPROBE_IA32 BIT(0)
/* vsyscall page is accessible on this MM */
-#define MM_CONTEXT_HAS_VSYSCALL BIT(1)
+#define MM_CONTEXT_HAS_VSYSCALL BIT(1)
+/* Allow LAM and SVA coexisting */
+#define MM_CONTEXT_FORCE_TAGGED_SVA BIT(2)
/*
* x86 has arch-specific MMU state beyond what lives in mm_struct.
@@ -114,6 +114,12 @@ static inline void mm_reset_untag_mask(struct mm_struct *mm)
mm->context.untag_mask = -1UL;
}
+#define arch_pgtable_dma_compat arch_pgtable_dma_compat
+static inline bool arch_pgtable_dma_compat(struct mm_struct *mm)
+{
+ return !mm_lam_cr3_mask(mm) ||
+ (mm->context.flags & MM_CONTEXT_FORCE_TAGGED_SVA);
+}
#else
static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
@@ -23,5 +23,6 @@
#define ARCH_GET_UNTAG_MASK 0x4001
#define ARCH_ENABLE_TAGGED_ADDR 0x4002
#define ARCH_GET_MAX_TAG_BITS 0x4003
+#define ARCH_FORCE_TAGGED_SVA 0x4004
#endif /* _ASM_X86_PRCTL_H */
@@ -783,6 +783,12 @@ static int prctl_enable_tagged_addr(struct mm_struct *mm, unsigned long nr_bits)
goto out;
}
+ if (mm_valid_pasid(mm) &&
+ !(mm->context.flags & MM_CONTEXT_FORCE_TAGGED_SVA)) {
+ ret = -EBUSY;
+ goto out;
+ }
+
if (!nr_bits) {
ret = -EINVAL;
goto out;
@@ -893,6 +899,12 @@ long do_arch_prctl_64(struct task_struct *task, int option, unsigned long arg2)
(unsigned long __user *)arg2);
case ARCH_ENABLE_TAGGED_ADDR:
return prctl_enable_tagged_addr(task->mm, arg2);
+ case ARCH_FORCE_TAGGED_SVA:
+ if (mmap_write_lock_killable(task->mm))
+ return -EINTR;
+ task->mm->context.flags |= MM_CONTEXT_FORCE_TAGGED_SVA;
+ mmap_write_unlock(task->mm);
+ return 0;
case ARCH_GET_MAX_TAG_BITS:
if (!cpu_feature_enabled(X86_FEATURE_LAM))
return put_user(0, (unsigned long __user *)arg2);
@@ -2,6 +2,8 @@
/*
* Helpers for IOMMU drivers implementing SVA
*/
+#include <linux/mm.h>
+#include <linux/mmu_context.h>
#include <linux/mutex.h>
#include <linux/sched/mm.h>
@@ -31,6 +33,15 @@ int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
min == 0 || max < min)
return -EINVAL;
+ /* Serialize against address tagging enabling */
+ if (mmap_write_lock_killable(mm))
+ return -EINTR;
+
+ if (!arch_pgtable_dma_compat(mm)) {
+ mmap_write_unlock(mm);
+ return -EBUSY;
+ }
+
mutex_lock(&iommu_sva_lock);
/* Is a PASID already associated with this mm? */
if (mm_valid_pasid(mm)) {
@@ -46,6 +57,7 @@ int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
mm_pasid_set(mm, pasid);
out:
mutex_unlock(&iommu_sva_lock);
+ mmap_write_unlock(mm);
return ret;
}
EXPORT_SYMBOL_GPL(iommu_sva_alloc_pasid);
@@ -35,4 +35,11 @@ static inline unsigned long mm_untag_mask(struct mm_struct *mm)
}
#endif
+#ifndef arch_pgtable_dma_compat
+static inline bool arch_pgtable_dma_compat(struct mm_struct *mm)
+{
+ return true;
+}
+#endif
+
#endif