diff mbox series

[V1,mlx5-next,2/4] vfio/mlx5: Manage the VF attach/detach callback from the PF

Message ID 20220508131053.241347-3-yishaih@nvidia.com (mailing list archive)
State New, archived
Headers show
Series Improve mlx5 live migration driver | expand

Commit Message

Yishai Hadas May 8, 2022, 1:10 p.m. UTC
Handle the VF attach/detach (enable/disable) notifications sent by the PF.

This lets the driver enable parallel migration of multiple VFs, which
will be introduced in the next patch.

As part of this, move the code that checks whether the VF is migratable
into a separate function, named mlx5vf_cmd_set_migratable() to match its
functionality.

Signed-off-by: Yishai Hadas <yishaih@nvidia.com>
Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
---
 drivers/vfio/pci/mlx5/cmd.c  | 63 ++++++++++++++++++++++++++++++++++++
 drivers/vfio/pci/mlx5/cmd.h  | 22 +++++++++++++
 drivers/vfio/pci/mlx5/main.c | 40 ++++-------------------
 3 files changed, 91 insertions(+), 34 deletions(-)

Comments

Alex Williamson May 9, 2022, 5:29 p.m. UTC | #1
On Sun, 8 May 2022 16:10:51 +0300
Yishai Hadas <yishaih@nvidia.com> wrote:

> Manage the VF attach/detach callback from the PF.
> 
> This lets the driver to enable parallel VFs migration as will be
> introduced in the next patch.
> 
> As part of this, reorganize the VF is migratable code to be in a
> separate function and rename it to be set_migratable() to match its
> functionality.
> 
> Signed-off-by: Yishai Hadas <yishaih@nvidia.com>
> Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
> ---
>  drivers/vfio/pci/mlx5/cmd.c  | 63 ++++++++++++++++++++++++++++++++++++
>  drivers/vfio/pci/mlx5/cmd.h  | 22 +++++++++++++
>  drivers/vfio/pci/mlx5/main.c | 40 ++++-------------------
>  3 files changed, 91 insertions(+), 34 deletions(-)
> 
> diff --git a/drivers/vfio/pci/mlx5/cmd.c b/drivers/vfio/pci/mlx5/cmd.c
> index 5c9f9218cc1d..5031978ae63a 100644
> --- a/drivers/vfio/pci/mlx5/cmd.c
> +++ b/drivers/vfio/pci/mlx5/cmd.c
> @@ -71,6 +71,69 @@ int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
>  	return ret;
>  }
>  
> +static int mlx5fv_vf_event(struct notifier_block *nb,
> +			   unsigned long event, void *data)
> +{
> +	struct mlx5vf_pci_core_device *mvdev =
> +		container_of(nb, struct mlx5vf_pci_core_device, nb);
> +
> +	mutex_lock(&mvdev->state_mutex);
> +	switch (event) {
> +	case MLX5_PF_NOTIFY_ENABLE_VF:
> +		mvdev->mdev_detach = false;
> +		break;
> +	case MLX5_PF_NOTIFY_DISABLE_VF:
> +		mvdev->mdev_detach = true;
> +		break;
> +	default:
> +		break;
> +	}
> +	mlx5vf_state_mutex_unlock(mvdev);
> +	return 0;
> +}
> +
> +void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev)
> +{
> +	mlx5_sriov_blocking_notifier_unregister(mvdev->mdev, mvdev->vf_id,
> +						&mvdev->nb);
> +}
> +
> +void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev)
> +{
> +	struct pci_dev *pdev = mvdev->core_device.pdev;
> +	int ret;
> +
> +	if (!pdev->is_virtfn)
> +		return;
> +
> +	mvdev->mdev = mlx5_vf_get_core_dev(pdev);
> +	if (!mvdev->mdev)
> +		return;
> +
> +	if (!MLX5_CAP_GEN(mvdev->mdev, migration))
> +		goto end;
> +
> +	mvdev->vf_id = pci_iov_vf_id(pdev);
> +	if (mvdev->vf_id < 0)
> +		goto end;
> +
> +	mutex_init(&mvdev->state_mutex);
> +	spin_lock_init(&mvdev->reset_lock);
> +	mvdev->nb.notifier_call = mlx5fv_vf_event;
> +	ret = mlx5_sriov_blocking_notifier_register(mvdev->mdev, mvdev->vf_id,
> +						    &mvdev->nb);
> +	if (ret)
> +		goto end;
> +
> +	mvdev->migrate_cap = 1;
> +	mvdev->core_device.vdev.migration_flags =
> +		VFIO_MIGRATION_STOP_COPY |
> +		VFIO_MIGRATION_P2P;
> +
> +end:
> +	mlx5_vf_put_core_dev(mvdev->mdev);
> +}
> +
>  int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id)
>  {
>  	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
> diff --git a/drivers/vfio/pci/mlx5/cmd.h b/drivers/vfio/pci/mlx5/cmd.h
> index 1392a11a9cc0..340a06b98007 100644
> --- a/drivers/vfio/pci/mlx5/cmd.h
> +++ b/drivers/vfio/pci/mlx5/cmd.h
> @@ -7,6 +7,7 @@
>  #define MLX5_VFIO_CMD_H
>  
>  #include <linux/kernel.h>
> +#include <linux/vfio_pci_core.h>
>  #include <linux/mlx5/driver.h>
>  
>  struct mlx5_vf_migration_file {
> @@ -24,13 +25,34 @@ struct mlx5_vf_migration_file {
>  	unsigned long last_offset;
>  };
>  
> +struct mlx5vf_pci_core_device {
> +	struct vfio_pci_core_device core_device;
> +	int vf_id;
> +	u16 vhca_id;
> +	u8 migrate_cap:1;
> +	u8 deferred_reset:1;
> +	/* protect migration state */
> +	struct mutex state_mutex;
> +	enum vfio_device_mig_state mig_state;
> +	/* protect the reset_done flow */
> +	spinlock_t reset_lock;
> +	struct mlx5_vf_migration_file *resuming_migf;
> +	struct mlx5_vf_migration_file *saving_migf;
> +	struct notifier_block nb;
> +	struct mlx5_core_dev *mdev;
> +	u8 mdev_detach:1;


This should be packed with the other bit fields, there's plenty of
space there.


> +};
> +
>  int mlx5vf_cmd_suspend_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
>  int mlx5vf_cmd_resume_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
>  int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
>  					  size_t *state_size);
>  int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id);
> +void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev);
> +void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev);
>  int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
>  			       struct mlx5_vf_migration_file *migf);
>  int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
>  			       struct mlx5_vf_migration_file *migf);
> +void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev);
>  #endif /* MLX5_VFIO_CMD_H */
> diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
> index bbec5d288fee..9716c87e31f9 100644
> --- a/drivers/vfio/pci/mlx5/main.c
> +++ b/drivers/vfio/pci/mlx5/main.c
> @@ -17,7 +17,6 @@
>  #include <linux/uaccess.h>
>  #include <linux/vfio.h>
>  #include <linux/sched/mm.h>
> -#include <linux/vfio_pci_core.h>
>  #include <linux/anon_inodes.h>
>  
>  #include "cmd.h"
> @@ -25,20 +24,6 @@
>  /* Arbitrary to prevent userspace from consuming endless memory */
>  #define MAX_MIGRATION_SIZE (512*1024*1024)
>  
> -struct mlx5vf_pci_core_device {
> -	struct vfio_pci_core_device core_device;
> -	u16 vhca_id;
> -	u8 migrate_cap:1;
> -	u8 deferred_reset:1;
> -	/* protect migration state */
> -	struct mutex state_mutex;
> -	enum vfio_device_mig_state mig_state;
> -	/* protect the reset_done flow */
> -	spinlock_t reset_lock;
> -	struct mlx5_vf_migration_file *resuming_migf;
> -	struct mlx5_vf_migration_file *saving_migf;
> -};
> -
>  static struct page *
>  mlx5vf_get_migration_page(struct mlx5_vf_migration_file *migf,
>  			  unsigned long offset)
> @@ -444,7 +429,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
>   * This function is called in all state_mutex unlock cases to
>   * handle a 'deferred_reset' if exists.
>   */
> -static void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
> +void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
>  {
>  again:
>  	spin_lock(&mvdev->reset_lock);
> @@ -596,24 +581,7 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
>  	if (!mvdev)
>  		return -ENOMEM;
>  	vfio_pci_core_init_device(&mvdev->core_device, pdev, &mlx5vf_pci_ops);
> -
> -	if (pdev->is_virtfn) {
> -		struct mlx5_core_dev *mdev =
> -			mlx5_vf_get_core_dev(pdev);
> -
> -		if (mdev) {
> -			if (MLX5_CAP_GEN(mdev, migration)) {
> -				mvdev->migrate_cap = 1;
> -				mvdev->core_device.vdev.migration_flags =
> -					VFIO_MIGRATION_STOP_COPY |
> -					VFIO_MIGRATION_P2P;
> -				mutex_init(&mvdev->state_mutex);
> -				spin_lock_init(&mvdev->reset_lock);
> -			}
> -			mlx5_vf_put_core_dev(mdev);
> -		}
> -	}
> -
> +	mlx5vf_cmd_set_migratable(mvdev);
>  	ret = vfio_pci_core_register_device(&mvdev->core_device);
>  	if (ret)
>  		goto out_free;
> @@ -622,6 +590,8 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
>  	return 0;
>  
>  out_free:
> +	if (mvdev->migrate_cap)
> +		mlx5vf_cmd_remove_migratable(mvdev);
>  	vfio_pci_core_uninit_device(&mvdev->core_device);
>  	kfree(mvdev);
>  	return ret;
> @@ -632,6 +602,8 @@ static void mlx5vf_pci_remove(struct pci_dev *pdev)
>  	struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
>  
>  	vfio_pci_core_unregister_device(&mvdev->core_device);
> +	if (mvdev->migrate_cap)
> +		mlx5vf_cmd_remove_migratable(mvdev);
>  	vfio_pci_core_uninit_device(&mvdev->core_device);
>  	kfree(mvdev);
>  }


Personally, I'd push the test into the function, i.e.:

void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev)
{
	if (!mvdev->migrate_cap)
		return;

	...
}

But it's clearly functional this way as well.  Please do fix the bit
field packing though.  Thanks,

Alex
Yishai Hadas May 10, 2022, 8:23 a.m. UTC | #2
On 09/05/2022 20:29, Alex Williamson wrote:
> On Sun, 8 May 2022 16:10:51 +0300
> Yishai Hadas <yishaih@nvidia.com> wrote:
>
>> Manage the VF attach/detach callback from the PF.
>>
>> This lets the driver to enable parallel VFs migration as will be
>> introduced in the next patch.
>>
>> As part of this, reorganize the VF is migratable code to be in a
>> separate function and rename it to be set_migratable() to match its
>> functionality.
>>
>> Signed-off-by: Yishai Hadas <yishaih@nvidia.com>
>> Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
>> ---
>>   drivers/vfio/pci/mlx5/cmd.c  | 63 ++++++++++++++++++++++++++++++++++++
>>   drivers/vfio/pci/mlx5/cmd.h  | 22 +++++++++++++
>>   drivers/vfio/pci/mlx5/main.c | 40 ++++-------------------
>>   3 files changed, 91 insertions(+), 34 deletions(-)
>>
>> diff --git a/drivers/vfio/pci/mlx5/cmd.c b/drivers/vfio/pci/mlx5/cmd.c
>> index 5c9f9218cc1d..5031978ae63a 100644
>> --- a/drivers/vfio/pci/mlx5/cmd.c
>> +++ b/drivers/vfio/pci/mlx5/cmd.c
>> @@ -71,6 +71,69 @@ int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
>>   	return ret;
>>   }
>>   
>> +static int mlx5fv_vf_event(struct notifier_block *nb,
>> +			   unsigned long event, void *data)
>> +{
>> +	struct mlx5vf_pci_core_device *mvdev =
>> +		container_of(nb, struct mlx5vf_pci_core_device, nb);
>> +
>> +	mutex_lock(&mvdev->state_mutex);
>> +	switch (event) {
>> +	case MLX5_PF_NOTIFY_ENABLE_VF:
>> +		mvdev->mdev_detach = false;
>> +		break;
>> +	case MLX5_PF_NOTIFY_DISABLE_VF:
>> +		mvdev->mdev_detach = true;
>> +		break;
>> +	default:
>> +		break;
>> +	}
>> +	mlx5vf_state_mutex_unlock(mvdev);
>> +	return 0;
>> +}
>> +
>> +void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev)
>> +{
>> +	mlx5_sriov_blocking_notifier_unregister(mvdev->mdev, mvdev->vf_id,
>> +						&mvdev->nb);
>> +}
>> +
>> +void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev)
>> +{
>> +	struct pci_dev *pdev = mvdev->core_device.pdev;
>> +	int ret;
>> +
>> +	if (!pdev->is_virtfn)
>> +		return;
>> +
>> +	mvdev->mdev = mlx5_vf_get_core_dev(pdev);
>> +	if (!mvdev->mdev)
>> +		return;
>> +
>> +	if (!MLX5_CAP_GEN(mvdev->mdev, migration))
>> +		goto end;
>> +
>> +	mvdev->vf_id = pci_iov_vf_id(pdev);
>> +	if (mvdev->vf_id < 0)
>> +		goto end;
>> +
>> +	mutex_init(&mvdev->state_mutex);
>> +	spin_lock_init(&mvdev->reset_lock);
>> +	mvdev->nb.notifier_call = mlx5fv_vf_event;
>> +	ret = mlx5_sriov_blocking_notifier_register(mvdev->mdev, mvdev->vf_id,
>> +						    &mvdev->nb);
>> +	if (ret)
>> +		goto end;
>> +
>> +	mvdev->migrate_cap = 1;
>> +	mvdev->core_device.vdev.migration_flags =
>> +		VFIO_MIGRATION_STOP_COPY |
>> +		VFIO_MIGRATION_P2P;
>> +
>> +end:
>> +	mlx5_vf_put_core_dev(mvdev->mdev);
>> +}
>> +
>>   int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id)
>>   {
>>   	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
>> diff --git a/drivers/vfio/pci/mlx5/cmd.h b/drivers/vfio/pci/mlx5/cmd.h
>> index 1392a11a9cc0..340a06b98007 100644
>> --- a/drivers/vfio/pci/mlx5/cmd.h
>> +++ b/drivers/vfio/pci/mlx5/cmd.h
>> @@ -7,6 +7,7 @@
>>   #define MLX5_VFIO_CMD_H
>>   
>>   #include <linux/kernel.h>
>> +#include <linux/vfio_pci_core.h>
>>   #include <linux/mlx5/driver.h>
>>   
>>   struct mlx5_vf_migration_file {
>> @@ -24,13 +25,34 @@ struct mlx5_vf_migration_file {
>>   	unsigned long last_offset;
>>   };
>>   
>> +struct mlx5vf_pci_core_device {
>> +	struct vfio_pci_core_device core_device;
>> +	int vf_id;
>> +	u16 vhca_id;
>> +	u8 migrate_cap:1;
>> +	u8 deferred_reset:1;
>> +	/* protect migration state */
>> +	struct mutex state_mutex;
>> +	enum vfio_device_mig_state mig_state;
>> +	/* protect the reset_done flow */
>> +	spinlock_t reset_lock;
>> +	struct mlx5_vf_migration_file *resuming_migf;
>> +	struct mlx5_vf_migration_file *saving_migf;
>> +	struct notifier_block nb;
>> +	struct mlx5_core_dev *mdev;
>> +	u8 mdev_detach:1;
>
> This should be packed with the other bit fields, there's plenty of
> space there.
>
Sure, will be part of V2.

>> +};
>> +
>>   int mlx5vf_cmd_suspend_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
>>   int mlx5vf_cmd_resume_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
>>   int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
>>   					  size_t *state_size);
>>   int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id);
>> +void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev);
>> +void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev);
>>   int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
>>   			       struct mlx5_vf_migration_file *migf);
>>   int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
>>   			       struct mlx5_vf_migration_file *migf);
>> +void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev);
>>   #endif /* MLX5_VFIO_CMD_H */
>> diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
>> index bbec5d288fee..9716c87e31f9 100644
>> --- a/drivers/vfio/pci/mlx5/main.c
>> +++ b/drivers/vfio/pci/mlx5/main.c
>> @@ -17,7 +17,6 @@
>>   #include <linux/uaccess.h>
>>   #include <linux/vfio.h>
>>   #include <linux/sched/mm.h>
>> -#include <linux/vfio_pci_core.h>
>>   #include <linux/anon_inodes.h>
>>   
>>   #include "cmd.h"
>> @@ -25,20 +24,6 @@
>>   /* Arbitrary to prevent userspace from consuming endless memory */
>>   #define MAX_MIGRATION_SIZE (512*1024*1024)
>>   
>> -struct mlx5vf_pci_core_device {
>> -	struct vfio_pci_core_device core_device;
>> -	u16 vhca_id;
>> -	u8 migrate_cap:1;
>> -	u8 deferred_reset:1;
>> -	/* protect migration state */
>> -	struct mutex state_mutex;
>> -	enum vfio_device_mig_state mig_state;
>> -	/* protect the reset_done flow */
>> -	spinlock_t reset_lock;
>> -	struct mlx5_vf_migration_file *resuming_migf;
>> -	struct mlx5_vf_migration_file *saving_migf;
>> -};
>> -
>>   static struct page *
>>   mlx5vf_get_migration_page(struct mlx5_vf_migration_file *migf,
>>   			  unsigned long offset)
>> @@ -444,7 +429,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
>>    * This function is called in all state_mutex unlock cases to
>>    * handle a 'deferred_reset' if exists.
>>    */
>> -static void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
>> +void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
>>   {
>>   again:
>>   	spin_lock(&mvdev->reset_lock);
>> @@ -596,24 +581,7 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
>>   	if (!mvdev)
>>   		return -ENOMEM;
>>   	vfio_pci_core_init_device(&mvdev->core_device, pdev, &mlx5vf_pci_ops);
>> -
>> -	if (pdev->is_virtfn) {
>> -		struct mlx5_core_dev *mdev =
>> -			mlx5_vf_get_core_dev(pdev);
>> -
>> -		if (mdev) {
>> -			if (MLX5_CAP_GEN(mdev, migration)) {
>> -				mvdev->migrate_cap = 1;
>> -				mvdev->core_device.vdev.migration_flags =
>> -					VFIO_MIGRATION_STOP_COPY |
>> -					VFIO_MIGRATION_P2P;
>> -				mutex_init(&mvdev->state_mutex);
>> -				spin_lock_init(&mvdev->reset_lock);
>> -			}
>> -			mlx5_vf_put_core_dev(mdev);
>> -		}
>> -	}
>> -
>> +	mlx5vf_cmd_set_migratable(mvdev);
>>   	ret = vfio_pci_core_register_device(&mvdev->core_device);
>>   	if (ret)
>>   		goto out_free;
>> @@ -622,6 +590,8 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
>>   	return 0;
>>   
>>   out_free:
>> +	if (mvdev->migrate_cap)
>> +		mlx5vf_cmd_remove_migratable(mvdev);
>>   	vfio_pci_core_uninit_device(&mvdev->core_device);
>>   	kfree(mvdev);
>>   	return ret;
>> @@ -632,6 +602,8 @@ static void mlx5vf_pci_remove(struct pci_dev *pdev)
>>   	struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
>>   
>>   	vfio_pci_core_unregister_device(&mvdev->core_device);
>> +	if (mvdev->migrate_cap)
>> +		mlx5vf_cmd_remove_migratable(mvdev);
>>   	vfio_pci_core_uninit_device(&mvdev->core_device);
>>   	kfree(mvdev);
>>   }
>
> Personally, I'd push the test into the function, ie.
>
> void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev)
> {
> 	if (!mvdev->migrate_cap)
> 		return;
>
> 	...
> }


Makes sense; this keeps the caller code symmetric and clean for both the
set and remove calls.

Will be part of V2.

> But it's clearly functional this way as well.  Please do fix the bit
> field packing though.  Thanks,
>
> Alex
>
diff mbox series

Patch

diff --git a/drivers/vfio/pci/mlx5/cmd.c b/drivers/vfio/pci/mlx5/cmd.c
index 5c9f9218cc1d..5031978ae63a 100644
--- a/drivers/vfio/pci/mlx5/cmd.c
+++ b/drivers/vfio/pci/mlx5/cmd.c
@@ -71,6 +71,69 @@  int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
 	return ret;
 }
 
+static int mlx5fv_vf_event(struct notifier_block *nb,
+			   unsigned long event, void *data)
+{
+	struct mlx5vf_pci_core_device *mvdev =
+		container_of(nb, struct mlx5vf_pci_core_device, nb);
+
+	mutex_lock(&mvdev->state_mutex);
+	switch (event) {
+	case MLX5_PF_NOTIFY_ENABLE_VF:
+		mvdev->mdev_detach = false;
+		break;
+	case MLX5_PF_NOTIFY_DISABLE_VF:
+		mvdev->mdev_detach = true;
+		break;
+	default:
+		break;
+	}
+	mlx5vf_state_mutex_unlock(mvdev);
+	return 0;
+}
+
+void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev)
+{
+	mlx5_sriov_blocking_notifier_unregister(mvdev->mdev, mvdev->vf_id,
+						&mvdev->nb);
+}
+
+void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev)
+{
+	struct pci_dev *pdev = mvdev->core_device.pdev;
+	int ret;
+
+	if (!pdev->is_virtfn)
+		return;
+
+	mvdev->mdev = mlx5_vf_get_core_dev(pdev);
+	if (!mvdev->mdev)
+		return;
+
+	if (!MLX5_CAP_GEN(mvdev->mdev, migration))
+		goto end;
+
+	mvdev->vf_id = pci_iov_vf_id(pdev);
+	if (mvdev->vf_id < 0)
+		goto end;
+
+	mutex_init(&mvdev->state_mutex);
+	spin_lock_init(&mvdev->reset_lock);
+	mvdev->nb.notifier_call = mlx5fv_vf_event;
+	ret = mlx5_sriov_blocking_notifier_register(mvdev->mdev, mvdev->vf_id,
+						    &mvdev->nb);
+	if (ret)
+		goto end;
+
+	mvdev->migrate_cap = 1;
+	mvdev->core_device.vdev.migration_flags =
+		VFIO_MIGRATION_STOP_COPY |
+		VFIO_MIGRATION_P2P;
+
+end:
+	mlx5_vf_put_core_dev(mvdev->mdev);
+}
+
 int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id)
 {
 	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
diff --git a/drivers/vfio/pci/mlx5/cmd.h b/drivers/vfio/pci/mlx5/cmd.h
index 1392a11a9cc0..340a06b98007 100644
--- a/drivers/vfio/pci/mlx5/cmd.h
+++ b/drivers/vfio/pci/mlx5/cmd.h
@@ -7,6 +7,7 @@ 
 #define MLX5_VFIO_CMD_H
 
 #include <linux/kernel.h>
+#include <linux/vfio_pci_core.h>
 #include <linux/mlx5/driver.h>
 
 struct mlx5_vf_migration_file {
@@ -24,13 +25,34 @@  struct mlx5_vf_migration_file {
 	unsigned long last_offset;
 };
 
+struct mlx5vf_pci_core_device {
+	struct vfio_pci_core_device core_device;
+	int vf_id;
+	u16 vhca_id;
+	u8 migrate_cap:1;
+	u8 deferred_reset:1;
+	/* protect migration state */
+	struct mutex state_mutex;
+	enum vfio_device_mig_state mig_state;
+	/* protect the reset_done flow */
+	spinlock_t reset_lock;
+	struct mlx5_vf_migration_file *resuming_migf;
+	struct mlx5_vf_migration_file *saving_migf;
+	struct notifier_block nb;
+	struct mlx5_core_dev *mdev;
+	u8 mdev_detach:1;
+};
+
 int mlx5vf_cmd_suspend_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
 int mlx5vf_cmd_resume_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
 int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
 					  size_t *state_size);
 int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id);
+void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev);
+void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev);
 int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
 			       struct mlx5_vf_migration_file *migf);
 int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
 			       struct mlx5_vf_migration_file *migf);
+void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev);
 #endif /* MLX5_VFIO_CMD_H */
diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
index bbec5d288fee..9716c87e31f9 100644
--- a/drivers/vfio/pci/mlx5/main.c
+++ b/drivers/vfio/pci/mlx5/main.c
@@ -17,7 +17,6 @@ 
 #include <linux/uaccess.h>
 #include <linux/vfio.h>
 #include <linux/sched/mm.h>
-#include <linux/vfio_pci_core.h>
 #include <linux/anon_inodes.h>
 
 #include "cmd.h"
@@ -25,20 +24,6 @@ 
 /* Arbitrary to prevent userspace from consuming endless memory */
 #define MAX_MIGRATION_SIZE (512*1024*1024)
 
-struct mlx5vf_pci_core_device {
-	struct vfio_pci_core_device core_device;
-	u16 vhca_id;
-	u8 migrate_cap:1;
-	u8 deferred_reset:1;
-	/* protect migration state */
-	struct mutex state_mutex;
-	enum vfio_device_mig_state mig_state;
-	/* protect the reset_done flow */
-	spinlock_t reset_lock;
-	struct mlx5_vf_migration_file *resuming_migf;
-	struct mlx5_vf_migration_file *saving_migf;
-};
-
 static struct page *
 mlx5vf_get_migration_page(struct mlx5_vf_migration_file *migf,
 			  unsigned long offset)
@@ -444,7 +429,7 @@  mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
  * This function is called in all state_mutex unlock cases to
  * handle a 'deferred_reset' if exists.
  */
-static void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
+void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
 {
 again:
 	spin_lock(&mvdev->reset_lock);
@@ -596,24 +581,7 @@  static int mlx5vf_pci_probe(struct pci_dev *pdev,
 	if (!mvdev)
 		return -ENOMEM;
 	vfio_pci_core_init_device(&mvdev->core_device, pdev, &mlx5vf_pci_ops);
-
-	if (pdev->is_virtfn) {
-		struct mlx5_core_dev *mdev =
-			mlx5_vf_get_core_dev(pdev);
-
-		if (mdev) {
-			if (MLX5_CAP_GEN(mdev, migration)) {
-				mvdev->migrate_cap = 1;
-				mvdev->core_device.vdev.migration_flags =
-					VFIO_MIGRATION_STOP_COPY |
-					VFIO_MIGRATION_P2P;
-				mutex_init(&mvdev->state_mutex);
-				spin_lock_init(&mvdev->reset_lock);
-			}
-			mlx5_vf_put_core_dev(mdev);
-		}
-	}
-
+	mlx5vf_cmd_set_migratable(mvdev);
 	ret = vfio_pci_core_register_device(&mvdev->core_device);
 	if (ret)
 		goto out_free;
@@ -622,6 +590,8 @@  static int mlx5vf_pci_probe(struct pci_dev *pdev,
 	return 0;
 
 out_free:
+	if (mvdev->migrate_cap)
+		mlx5vf_cmd_remove_migratable(mvdev);
 	vfio_pci_core_uninit_device(&mvdev->core_device);
 	kfree(mvdev);
 	return ret;
@@ -632,6 +602,8 @@  static void mlx5vf_pci_remove(struct pci_dev *pdev)
 	struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
 
 	vfio_pci_core_unregister_device(&mvdev->core_device);
+	if (mvdev->migrate_cap)
+		mlx5vf_cmd_remove_migratable(mvdev);
 	vfio_pci_core_uninit_device(&mvdev->core_device);
 	kfree(mvdev);
 }