@@ -1411,6 +1411,18 @@ static int vfio_mdev_domain_type(struct device *dev, void *data)
return -EINVAL;
}
+static int vfio_parent_bus_type(struct device *dev, void *data)
+{
+ struct bus_type **bus = data;
+
+ if (*bus && *bus != dev->parent->bus)
+ return -EINVAL;
+
+ *bus = dev->parent->bus;
+
+ return 0;
+}
+
static int vfio_iommu_type1_attach_group(void *iommu_data,
struct iommu_group *iommu_group)
{
@@ -1458,6 +1470,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
enum mdev_domain_type type = 0;
symbol_put(mdev_bus_type);
+ mdev_bus = NULL;
/* Determine the domain type: */
ret = iommu_group_for_each_dev(iommu_group, &type,
@@ -1479,7 +1492,14 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
return 0;
case DOMAIN_TYPE_ATTACH_PARENT:
- /* FALLTHROUGH */
+ bus = NULL;
+ group->attach_parent = true;
+ /* Set @bus to bus type of the parent: */
+ ret = iommu_group_for_each_dev(iommu_group, &bus,
+ vfio_parent_bus_type);
+ if (ret)
+ goto out_free;
+ break;
default:
ret = -EINVAL;
goto out_free;