@@ -1552,8 +1552,9 @@ static const struct file_operations vfio_device_fops = {
* vfio_file_iommu_group - Return the struct iommu_group for the vfio group file
* @file: VFIO group file
*
- * The returned iommu_group is valid as long as a ref is held on the file.
- * This function is deprecated, only the SPAPR path in kvm should call it.
+ * The returned iommu_group is valid as long as a ref is held on the file. This
+ * returns a reference on the group. This function is deprecated, only the SPAPR
+ * path in kvm should call it.
*/
struct iommu_group *vfio_file_iommu_group(struct file *file)
{
@@ -1564,6 +1565,7 @@ struct iommu_group *vfio_file_iommu_group(struct file *file)
if (!vfio_file_is_group(file))
return NULL;
+ iommu_group_ref_get(group->iommu_group);
return group->iommu_group;
}
EXPORT_SYMBOL_GPL(vfio_file_iommu_group);
@@ -24,6 +24,9 @@
struct kvm_vfio_group {
struct list_head node;
struct file *file;
+#ifdef CONFIG_SPAPR_TCE_IOMMU
+ struct iommu_group *iommu_group;
+#endif
};
struct kvm_vfio {
@@ -97,12 +100,12 @@ static struct iommu_group *kvm_vfio_file_iommu_group(struct file *file)
static void kvm_spapr_tce_release_vfio_group(struct kvm *kvm,
struct kvm_vfio_group *kvg)
{
- struct iommu_group *grp = kvm_vfio_file_iommu_group(kvg->file);
-
- if (WARN_ON_ONCE(!grp))
+ if (WARN_ON_ONCE(!kvg->iommu_group))
return;
- kvm_spapr_tce_release_iommu_group(kvm, grp);
+ kvm_spapr_tce_release_iommu_group(kvm, kvg->iommu_group);
+ iommu_group_put(kvg->iommu_group);
+ kvg->iommu_group = NULL;
}
#endif
@@ -252,19 +255,19 @@ static int kvm_vfio_group_set_spapr_tce(struct kvm_device *dev,
mutex_lock(&kv->lock);
list_for_each_entry(kvg, &kv->group_list, node) {
- struct iommu_group *grp;
-
if (kvg->file != f.file)
continue;
- grp = kvm_vfio_file_iommu_group(kvg->file);
- if (WARN_ON_ONCE(!grp)) {
- ret = -EIO;
- goto err_fdput;
+ if (!kvg->iommu_group) {
+ kvg->iommu_group = kvm_vfio_file_iommu_group(kvg->file);
+ if (WARN_ON_ONCE(!kvg->iommu_group)) {
+ ret = -EIO;
+ goto err_fdput;
+ }
}
ret = kvm_spapr_tce_attach_iommu_group(dev->kvm, param.tablefd,
- grp);
+ kvg->iommu_group);
break;
}