@@ -734,6 +734,51 @@ bool vfio_assert_device_open(struct vfio_device *device)
return !WARN_ON_ONCE(!READ_ONCE(device->open_count));
}
+static int vfio_device_first_open(struct vfio_device *device)
+{
+ int ret;
+
+ lockdep_assert_held(&device->dev_set->lock);
+
+ if (!try_module_get(device->dev->driver->owner))
+ return -ENODEV;
+
+ /*
+ * Here we pass the KVM pointer with the group under the lock. If the
+ * device driver will use it, it must obtain a reference and release it
+ * during close_device.
+ */
+ mutex_lock(&device->group->group_lock);
+ device->kvm = device->group->kvm;
+ if (device->ops->open_device) {
+ ret = device->ops->open_device(device);
+ if (ret)
+ goto err_module_put;
+ }
+ vfio_device_container_register(device);
+ mutex_unlock(&device->group->group_lock);
+ return 0;
+
+err_module_put:
+ device->kvm = NULL;
+ mutex_unlock(&device->group->group_lock);
+ module_put(device->dev->driver->owner);
+ return ret;
+}
+
+static void vfio_device_last_close(struct vfio_device *device)
+{
+ lockdep_assert_held(&device->dev_set->lock);
+
+ mutex_lock(&device->group->group_lock);
+ vfio_device_container_unregister(device);
+ if (device->ops->close_device)
+ device->ops->close_device(device);
+ device->kvm = NULL;
+ mutex_unlock(&device->group->group_lock);
+ module_put(device->dev->driver->owner);
+}
+
static struct file *vfio_device_open(struct vfio_device *device)
{
struct file *filep;
@@ -745,29 +790,12 @@ static struct file *vfio_device_open(struct vfio_device *device)
if (ret)
return ERR_PTR(ret);
- if (!try_module_get(device->dev->driver->owner)) {
- ret = -ENODEV;
- goto err_unassign_container;
- }
-
mutex_lock(&device->dev_set->lock);
device->open_count++;
if (device->open_count == 1) {
- /*
- * Here we pass the KVM pointer with the group under the read
- * lock. If the device driver will use it, it must obtain a
- * reference and release it during close_device.
- */
- mutex_lock(&device->group->group_lock);
- device->kvm = device->group->kvm;
-
- if (device->ops->open_device) {
- ret = device->ops->open_device(device);
- if (ret)
- goto err_undo_count;
- }
- vfio_device_container_register(device);
- mutex_unlock(&device->group->group_lock);
+ ret = vfio_device_first_open(device);
+ if (ret)
+ goto err_unassign_container;
}
mutex_unlock(&device->dev_set->lock);
@@ -800,20 +828,11 @@ static struct file *vfio_device_open(struct vfio_device *device)
err_close_device:
mutex_lock(&device->dev_set->lock);
- mutex_lock(&device->group->group_lock);
- if (device->open_count == 1 && device->ops->close_device) {
- device->ops->close_device(device);
-
- vfio_device_container_unregister(device);
- }
-err_undo_count:
- mutex_unlock(&device->group->group_lock);
+ if (device->open_count == 1)
+ vfio_device_last_close(device);
+err_unassign_container:
device->open_count--;
- if (device->open_count == 0 && device->kvm)
- device->kvm = NULL;
mutex_unlock(&device->dev_set->lock);
- module_put(device->dev->driver->owner);
-err_unassign_container:
vfio_device_unassign_container(device);
return ERR_PTR(ret);
}
@@ -1016,19 +1035,11 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep)
mutex_lock(&device->dev_set->lock);
vfio_assert_device_open(device);
- mutex_lock(&device->group->group_lock);
- if (device->open_count == 1 && device->ops->close_device)
- device->ops->close_device(device);
-
- vfio_device_container_unregister(device);
- mutex_unlock(&device->group->group_lock);
+ if (device->open_count == 1)
+ vfio_device_last_close(device);
device->open_count--;
- if (device->open_count == 0)
- device->kvm = NULL;
mutex_unlock(&device->dev_set->lock);
- module_put(device->dev->driver->owner);
-
vfio_device_unassign_container(device);
vfio_device_put_registration(device);