@@ -2740,14 +2740,49 @@ int ib_dma_virt_map_sg(struct ib_device *dev, struct scatterlist *sg, int nents)
EXPORT_SYMBOL(ib_dma_virt_map_sg);
#endif /* CONFIG_INFINIBAND_VIRT_DMA */
+static struct ib_device *get_ibdev_from_ndev(struct net_device *ndev)
+{
+ unsigned long index;
+ struct ib_device *dev;
+ int i;
+
+ down_read(&devices_rwsem);
+ xa_for_each_marked(&devices, index, dev, DEVICE_REGISTERED) {
+ if (!dev->ops.get_netdev)
+ continue;
+
+ for (i = 0; i < dev->phys_port_cnt; i++) {
+ struct net_device *netdev;
+
+ netdev = dev->ops.get_netdev(dev, i+1);
+ if (!netdev)
+ continue;
+
+ dev_put(netdev);
+ if (ndev == netdev) {
+ up_read(&devices_rwsem);
+ if (!ib_device_try_get(dev))
+ dev = NULL;
+ return dev;
+ }
+ }
+ }
+ up_read(&devices_rwsem);
+ return NULL;
+}
+
static int rdma_netns_notify(struct notifier_block *not_blk,
unsigned long event, void *arg)
{
struct net_device *ndev = netdev_notifier_info_to_dev(arg);
struct ib_device *ibdev = ib_device_get_by_netdev(ndev, RDMA_DRIVER_UNKNOWN);
- if (!ibdev)
- return NOTIFY_OK;
+ if (!ibdev) {
+ /* This is for MLX4/5 */
+ ibdev = get_ibdev_from_ndev(ndev);
+ if (!ibdev)
+ return NOTIFY_OK;
+ }
/* This will exclude IB device */
if (rdma_protocol_ib(ibdev, rdma_start_port(ibdev))) {