diff mbox series

[1/2] PCI: hv: Use IDR to generate transaction IDs for VMBus hardening

Message ID 20220318174848.290621-2-parri.andrea@gmail.com (mailing list archive)
State Superseded
Headers show
Series PCI: hv: Miscellaneous changes | expand

Commit Message

Andrea Parri March 18, 2022, 5:48 p.m. UTC
Currently, pointers to guest memory are passed to Hyper-V as transaction
IDs in hv_pci.  In the face of errors or malicious behavior in Hyper-V,
hv_pci should not expose or trust the transaction IDs returned by
Hyper-V to be valid guest memory addresses.  Instead, use small integers
generated by IDR as request (transaction) IDs.

Suggested-by: Michael Kelley <mikelley@microsoft.com>
Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
---
 drivers/pci/controller/pci-hyperv.c | 190 ++++++++++++++++++++--------
 1 file changed, 135 insertions(+), 55 deletions(-)

Comments

Saurabh Singh Sengar March 19, 2022, 7:47 a.m. UTC | #1
> -----Original Message-----
> From: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
> Sent: 18 March 2022 23:19
> To: KY Srinivasan <kys@microsoft.com>; Haiyang Zhang
> <haiyangz@microsoft.com>; Stephen Hemminger
> <sthemmin@microsoft.com>; Wei Liu <wei.liu@kernel.org>; Dexuan Cui
> <decui@microsoft.com>; Michael Kelley (LINUX) <mikelley@microsoft.com>;
> Wei Hu <weh@microsoft.com>; Lorenzo Pieralisi
> <lorenzo.pieralisi@arm.com>; Rob Herring <robh@kernel.org>; Krzysztof
> Wilczynski <kw@linux.com>; Bjorn Helgaas <bhelgaas@google.com>
> Cc: linux-pci@vger.kernel.org; linux-hyperv@vger.kernel.org; linux-
> kernel@vger.kernel.org; Andrea Parri (Microsoft) <parri.andrea@gmail.com>
> Subject: [EXTERNAL] [PATCH 1/2] PCI: hv: Use IDR to generate transaction IDs
> for VMBus hardening
> 
> Currently, pointers to guest memory are passed to Hyper-V as transaction
> IDs in hv_pci.  In the face of errors or malicious behavior in Hyper-V,
> hv_pci should not expose or trust the transaction IDs returned by
> Hyper-V to be valid guest memory addresses.  Instead, use small integers
> generated by IDR as request (transaction) IDs.
> 
> Suggested-by: Michael Kelley <mikelley@microsoft.com>
> Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
> ---
>  drivers/pci/controller/pci-hyperv.c | 190 ++++++++++++++++++++--------
>  1 file changed, 135 insertions(+), 55 deletions(-)
> 
> diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-
> hyperv.c
> index ae0bc2fee4ca8..fbc62aab08fdc 100644
> --- a/drivers/pci/controller/pci-hyperv.c
> +++ b/drivers/pci/controller/pci-hyperv.c
> @@ -495,6 +495,9 @@ struct hv_pcibus_device {
>  	spinlock_t device_list_lock;	/* Protect lists below */
>  	void __iomem *cfg_addr;
> 
> +	spinlock_t idr_lock; /* Serialize accesses to the IDR */
> +	struct idr idr; /* Map guest memory addresses */
> +
>  	struct list_head children;
>  	struct list_head dr_list;
> 
> @@ -1208,6 +1211,27 @@ static void hv_pci_read_config_compl(void
> *context, struct pci_response *resp,
>  	complete(&comp->comp_pkt.host_event);
>  }
> 
> +static inline int alloc_request_id(struct hv_pcibus_device *hbus,
> +				   void *ptr, gfp_t gfp)
> +{
> +	unsigned long flags;
> +	int req_id;
> +
> +	spin_lock_irqsave(&hbus->idr_lock, flags);
> +	req_id = idr_alloc(&hbus->idr, ptr, 1, 0, gfp);

[Saurabh Singh Sengar] In many places we call alloc_request_id with GFP_KERNEL, which results in this allocation being done inside a spin lock with GFP_KERNEL.
Is this a good opportunity to use idr_preload?

> +	spin_unlock_irqrestore(&hbus->idr_lock, flags);
> +	return req_id;
> +}
> +
> +static inline void remove_request_id(struct hv_pcibus_device *hbus, int
> req_id)
> +{
> +	unsigned long flags;
> +
> +	spin_lock_irqsave(&hbus->idr_lock, flags);
> +	idr_remove(&hbus->idr, req_id);
> +	spin_unlock_irqrestore(&hbus->idr_lock, flags);
> +}
> +
>  /**
>   * hv_read_config_block() - Sends a read config block request to
>   * the back-end driver running in the Hyper-V parent partition.
> @@ -1232,7 +1256,7 @@ static int hv_read_config_block(struct pci_dev
> *pdev, void *buf,
>  	} pkt;
>  	struct hv_read_config_compl comp_pkt;
>  	struct pci_read_block *read_blk;
> -	int ret;
> +	int req_id, ret;
> 
>  	if (len == 0 || len > HV_CONFIG_BLOCK_SIZE_MAX)
>  		return -EINVAL;
> @@ -1250,16 +1274,19 @@ static int hv_read_config_block(struct pci_dev
> *pdev, void *buf,
>  	read_blk->block_id = block_id;
>  	read_blk->bytes_requested = len;
> 
> +	req_id = alloc_request_id(hbus, &pkt.pkt, GFP_KERNEL);
> +	if (req_id < 0)
> +		return req_id;
> +
>  	ret = vmbus_sendpacket(hbus->hdev->channel, read_blk,
> -			       sizeof(*read_blk), (unsigned long)&pkt.pkt,
> -			       VM_PKT_DATA_INBAND,
> +			       sizeof(*read_blk), req_id,
> VM_PKT_DATA_INBAND,
> 
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	ret = wait_for_response(hbus->hdev,
> &comp_pkt.comp_pkt.host_event);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	if (comp_pkt.comp_pkt.completion_status != 0 ||
>  	    comp_pkt.bytes_returned == 0) {
> @@ -1267,11 +1294,14 @@ static int hv_read_config_block(struct pci_dev
> *pdev, void *buf,
>  			"Read Config Block failed: 0x%x,
> bytes_returned=%d\n",
>  			comp_pkt.comp_pkt.completion_status,
>  			comp_pkt.bytes_returned);
> -		return -EIO;
> +		ret = -EIO;
> +		goto exit;
>  	}
> 
>  	*bytes_returned = comp_pkt.bytes_returned;
> -	return 0;
> +exit:
> +	remove_request_id(hbus, req_id);
> +	return ret;
>  }
> 
>  /**
> @@ -1313,8 +1343,8 @@ static int hv_write_config_block(struct pci_dev
> *pdev, void *buf,
>  	} pkt;
>  	struct hv_pci_compl comp_pkt;
>  	struct pci_write_block *write_blk;
> +	int req_id, ret;
>  	u32 pkt_size;
> -	int ret;
> 
>  	if (len == 0 || len > HV_CONFIG_BLOCK_SIZE_MAX)
>  		return -EINVAL;
> @@ -1340,24 +1370,30 @@ static int hv_write_config_block(struct pci_dev
> *pdev, void *buf,
>  	 */
>  	pkt_size += sizeof(pkt.reserved);
> 
> +	req_id = alloc_request_id(hbus, &pkt.pkt, GFP_KERNEL);
> +	if (req_id < 0)
> +		return req_id;
> +
>  	ret = vmbus_sendpacket(hbus->hdev->channel, write_blk, pkt_size,
> -			       (unsigned long)&pkt.pkt,
> VM_PKT_DATA_INBAND,
> +			       req_id, VM_PKT_DATA_INBAND,
> 
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	ret = wait_for_response(hbus->hdev, &comp_pkt.host_event);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	if (comp_pkt.completion_status != 0) {
>  		dev_err(&hbus->hdev->device,
>  			"Write Config Block failed: 0x%x\n",
>  			comp_pkt.completion_status);
> -		return -EIO;
> +		ret = -EIO;
>  	}
> 
> -	return 0;
> +exit:
> +	remove_request_id(hbus, req_id);
> +	return ret;
>  }
> 
>  /**
> @@ -1407,7 +1443,7 @@ static void hv_int_desc_free(struct hv_pci_dev
> *hpdev,
>  	int_pkt->wslot.slot = hpdev->desc.win_slot.slot;
>  	int_pkt->int_desc = *int_desc;
>  	vmbus_sendpacket(hpdev->hbus->hdev->channel, int_pkt,
> sizeof(*int_pkt),
> -			 (unsigned long)&ctxt.pkt, VM_PKT_DATA_INBAND,
> 0);
> +			 0, VM_PKT_DATA_INBAND, 0);
>  	kfree(int_desc);
>  }
> 
> @@ -1688,9 +1724,8 @@ static void hv_compose_msi_msg(struct irq_data
> *data, struct msi_msg *msg)
>  			struct pci_create_interrupt3 v3;
>  		} int_pkts;
>  	} __packed ctxt;
> -
> +	int req_id, ret;
>  	u32 size;
> -	int ret;
> 
>  	pdev = msi_desc_to_pci_dev(irq_data_get_msi_desc(data));
>  	dest = irq_data_get_effective_affinity_mask(data);
> @@ -1750,15 +1785,18 @@ static void hv_compose_msi_msg(struct
> irq_data *data, struct msi_msg *msg)
>  		goto free_int_desc;
>  	}
> 
> +	req_id = alloc_request_id(hbus, &ctxt.pci_pkt, GFP_ATOMIC);
> +	if (req_id < 0)
> +		goto free_int_desc;
> +
>  	ret = vmbus_sendpacket(hpdev->hbus->hdev->channel,
> &ctxt.int_pkts,
> -			       size, (unsigned long)&ctxt.pci_pkt,
> -			       VM_PKT_DATA_INBAND,
> +			       size, req_id, VM_PKT_DATA_INBAND,
> 
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret) {
>  		dev_err(&hbus->hdev->device,
>  			"Sending request for interrupt failed: 0x%x",
>  			comp.comp_pkt.completion_status);
> -		goto free_int_desc;
> +		goto remove_id;
>  	}
> 
>  	/*
> @@ -1811,7 +1849,7 @@ static void hv_compose_msi_msg(struct irq_data
> *data, struct msi_msg *msg)
>  		dev_err(&hbus->hdev->device,
>  			"Request for interrupt failed: 0x%x",
>  			comp.comp_pkt.completion_status);
> -		goto free_int_desc;
> +		goto remove_id;
>  	}
> 
>  	/*
> @@ -1827,11 +1865,14 @@ static void hv_compose_msi_msg(struct
> irq_data *data, struct msi_msg *msg)
>  	msg->address_lo = comp.int_desc.address & 0xffffffff;
>  	msg->data = comp.int_desc.data;
> 
> +	remove_request_id(hbus, req_id);
>  	put_pcichild(hpdev);
>  	return;
> 
>  enable_tasklet:
>  	tasklet_enable(&channel->callback_event);
> +remove_id:
> +	remove_request_id(hbus, req_id);
>  free_int_desc:
>  	kfree(int_desc);
>  drop_reference:
> @@ -2258,7 +2299,7 @@ static struct hv_pci_dev
> *new_pcichild_device(struct hv_pcibus_device *hbus,
>  		u8 buffer[sizeof(struct pci_child_message)];
>  	} pkt;
>  	unsigned long flags;
> -	int ret;
> +	int req_id, ret;
> 
>  	hpdev = kzalloc(sizeof(*hpdev), GFP_KERNEL);
>  	if (!hpdev)
> @@ -2275,16 +2316,19 @@ static struct hv_pci_dev
> *new_pcichild_device(struct hv_pcibus_device *hbus,
>  	res_req->message_type.type =
> PCI_QUERY_RESOURCE_REQUIREMENTS;
>  	res_req->wslot.slot = desc->win_slot.slot;
> 
> +	req_id = alloc_request_id(hbus, &pkt.init_packet, GFP_KERNEL);
> +	if (req_id < 0)
> +		goto error;
> +
>  	ret = vmbus_sendpacket(hbus->hdev->channel, res_req,
> -			       sizeof(struct pci_child_message),
> -			       (unsigned long)&pkt.init_packet,
> +			       sizeof(struct pci_child_message), req_id,
>  			       VM_PKT_DATA_INBAND,
> 
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		goto error;
> +		goto remove_id;
> 
>  	if (wait_for_response(hbus->hdev, &comp_pkt.host_event))
> -		goto error;
> +		goto remove_id;
> 
>  	hpdev->desc = *desc;
>  	refcount_set(&hpdev->refs, 1);
> @@ -2293,8 +2337,11 @@ static struct hv_pci_dev
> *new_pcichild_device(struct hv_pcibus_device *hbus,
> 
>  	list_add_tail(&hpdev->list_entry, &hbus->children);
>  	spin_unlock_irqrestore(&hbus->device_list_lock, flags);
> +	remove_request_id(hbus, req_id);
>  	return hpdev;
> 
> +remove_id:
> +	remove_request_id(hbus, req_id);
>  error:
>  	kfree(hpdev);
>  	return NULL;
> @@ -2648,8 +2695,7 @@ static void hv_eject_device_work(struct
> work_struct *work)
>  	ejct_pkt = (struct pci_eject_response *)&ctxt.pkt.message;
>  	ejct_pkt->message_type.type = PCI_EJECTION_COMPLETE;
>  	ejct_pkt->wslot.slot = hpdev->desc.win_slot.slot;
> -	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt,
> -			 sizeof(*ejct_pkt), (unsigned long)&ctxt.pkt,
> +	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt, sizeof(*ejct_pkt),
> 0,
>  			 VM_PKT_DATA_INBAND, 0);
> 
>  	/* For the get_pcichild() in hv_pci_eject_device() */
> @@ -2709,6 +2755,7 @@ static void hv_pci_onchannelcallback(void
> *context)
>  	struct pci_dev_inval_block *inval;
>  	struct pci_dev_incoming *dev_message;
>  	struct hv_pci_dev *hpdev;
> +	unsigned long flags;
> 
>  	buffer = kmalloc(bufferlen, GFP_ATOMIC);
>  	if (!buffer)
> @@ -2743,11 +2790,19 @@ static void hv_pci_onchannelcallback(void
> *context)
>  		switch (desc->type) {
>  		case VM_PKT_COMP:
> 
> -			/*
> -			 * The host is trusted, and thus it's safe to interpret
> -			 * this transaction ID as a pointer.
> -			 */
> -			comp_packet = (struct pci_packet *)req_id;
> +			if (req_id > INT_MAX) {
> +				dev_err_ratelimited(&hbus->hdev->device,
> +						    "Request ID >
> INT_MAX\n");
> +				break;
> +			}
> +			spin_lock_irqsave(&hbus->idr_lock, flags);
> +			comp_packet = (struct pci_packet *)idr_find(&hbus-
> >idr, req_id);
> +			spin_unlock_irqrestore(&hbus->idr_lock, flags);
> +			if (!comp_packet) {
> +				dev_warn_ratelimited(&hbus->hdev->device,
> +						     "Request ID not found\n");
> +				break;
> +			}
>  			response = (struct pci_response *)buffer;
>  			comp_packet->completion_func(comp_packet-
> >compl_ctxt,
>  						     response,
> @@ -2858,8 +2913,7 @@ static int hv_pci_protocol_negotiation(struct
> hv_device *hdev,
>  	struct pci_version_request *version_req;
>  	struct hv_pci_compl comp_pkt;
>  	struct pci_packet *pkt;
> -	int ret;
> -	int i;
> +	int req_id, ret, i;
> 
>  	/*
>  	 * Initiate the handshake with the host and negotiate
> @@ -2877,12 +2931,18 @@ static int hv_pci_protocol_negotiation(struct
> hv_device *hdev,
>  	version_req = (struct pci_version_request *)&pkt->message;
>  	version_req->message_type.type =
> PCI_QUERY_PROTOCOL_VERSION;
> 
> +	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
> +	if (req_id < 0) {
> +		kfree(pkt);
> +		return req_id;
> +	}
> +
>  	for (i = 0; i < num_version; i++) {
>  		version_req->protocol_version = version[i];
>  		ret = vmbus_sendpacket(hdev->channel, version_req,
> -				sizeof(struct pci_version_request),
> -				(unsigned long)pkt, VM_PKT_DATA_INBAND,
> -
> 	VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
> +				       sizeof(struct pci_version_request),
> +				       req_id, VM_PKT_DATA_INBAND,
> +
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  		if (!ret)
>  			ret = wait_for_response(hdev,
> &comp_pkt.host_event);
> 
> @@ -2917,6 +2977,7 @@ static int hv_pci_protocol_negotiation(struct
> hv_device *hdev,
>  	ret = -EPROTO;
> 
>  exit:
> +	remove_request_id(hbus, req_id);
>  	kfree(pkt);
>  	return ret;
>  }
> @@ -3079,7 +3140,7 @@ static int hv_pci_enter_d0(struct hv_device *hdev)
>  	struct pci_bus_d0_entry *d0_entry;
>  	struct hv_pci_compl comp_pkt;
>  	struct pci_packet *pkt;
> -	int ret;
> +	int req_id, ret;
> 
>  	/*
>  	 * Tell the host that the bus is ready to use, and moved into the
> @@ -3098,8 +3159,14 @@ static int hv_pci_enter_d0(struct hv_device
> *hdev)
>  	d0_entry->message_type.type = PCI_BUS_D0ENTRY;
>  	d0_entry->mmio_base = hbus->mem_config->start;
> 
> +	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
> +	if (req_id < 0) {
> +		kfree(pkt);
> +		return req_id;
> +	}
> +
>  	ret = vmbus_sendpacket(hdev->channel, d0_entry,
> sizeof(*d0_entry),
> -			       (unsigned long)pkt, VM_PKT_DATA_INBAND,
> +			       req_id, VM_PKT_DATA_INBAND,
> 
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (!ret)
>  		ret = wait_for_response(hdev, &comp_pkt.host_event);
> @@ -3112,12 +3179,10 @@ static int hv_pci_enter_d0(struct hv_device
> *hdev)
>  			"PCI Pass-through VSP failed D0 Entry with status
> %x\n",
>  			comp_pkt.completion_status);
>  		ret = -EPROTO;
> -		goto exit;
>  	}
> 
> -	ret = 0;
> -
>  exit:
> +	remove_request_id(hbus, req_id);
>  	kfree(pkt);
>  	return ret;
>  }
> @@ -3175,11 +3240,10 @@ static int hv_send_resources_allocated(struct
> hv_device *hdev)
>  	struct pci_resources_assigned *res_assigned;
>  	struct pci_resources_assigned2 *res_assigned2;
>  	struct hv_pci_compl comp_pkt;
> +	int wslot, req_id, ret = 0;
>  	struct hv_pci_dev *hpdev;
>  	struct pci_packet *pkt;
>  	size_t size_res;
> -	int wslot;
> -	int ret;
> 
>  	size_res = (hbus->protocol_version < PCI_PROTOCOL_VERSION_1_2)
>  			? sizeof(*res_assigned) : sizeof(*res_assigned2);
> @@ -3188,7 +3252,11 @@ static int hv_send_resources_allocated(struct
> hv_device *hdev)
>  	if (!pkt)
>  		return -ENOMEM;
> 
> -	ret = 0;
> +	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
> +	if (req_id < 0) {
> +		kfree(pkt);
> +		return req_id;
> +	}
> 
>  	for (wslot = 0; wslot < 256; wslot++) {
>  		hpdev = get_pcichild_wslot(hbus, wslot);
> @@ -3215,10 +3283,9 @@ static int hv_send_resources_allocated(struct
> hv_device *hdev)
>  		}
>  		put_pcichild(hpdev);
> 
> -		ret = vmbus_sendpacket(hdev->channel, &pkt->message,
> -				size_res, (unsigned long)pkt,
> -				VM_PKT_DATA_INBAND,
> -
> 	VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
> +		ret = vmbus_sendpacket(hdev->channel, &pkt->message,
> size_res,
> +				       req_id, VM_PKT_DATA_INBAND,
> +
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  		if (!ret)
>  			ret = wait_for_response(hdev,
> &comp_pkt.host_event);
>  		if (ret)
> @@ -3235,6 +3302,7 @@ static int hv_send_resources_allocated(struct
> hv_device *hdev)
>  		hbus->wslot_res_allocated = wslot;
>  	}
> 
> +	remove_request_id(hbus, req_id);
>  	kfree(pkt);
>  	return ret;
>  }
> @@ -3412,6 +3480,7 @@ static int hv_pci_probe(struct hv_device *hdev,
>  	spin_lock_init(&hbus->config_lock);
>  	spin_lock_init(&hbus->device_list_lock);
>  	spin_lock_init(&hbus->retarget_msi_interrupt_lock);
> +	spin_lock_init(&hbus->idr_lock);
>  	hbus->wq = alloc_ordered_workqueue("hv_pci_%x", 0,
>  					   hbus->bridge->domain_nr);
>  	if (!hbus->wq) {
> @@ -3419,6 +3488,7 @@ static int hv_pci_probe(struct hv_device *hdev,
>  		goto free_dom;
>  	}
> 
> +	idr_init(&hbus->idr);
>  	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL,
> 0,
>  			 hv_pci_onchannelcallback, hbus);
>  	if (ret)
> @@ -3537,6 +3607,7 @@ static int hv_pci_probe(struct hv_device *hdev,
>  	hv_free_config_window(hbus);
>  close:
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
>  destroy_wq:
>  	destroy_workqueue(hbus->wq);
>  free_dom:
> @@ -3556,7 +3627,7 @@ static int hv_pci_bus_exit(struct hv_device *hdev,
> bool keep_devs)
>  	struct hv_pci_compl comp_pkt;
>  	struct hv_pci_dev *hpdev, *tmp;
>  	unsigned long flags;
> -	int ret;
> +	int req_id, ret;
> 
>  	/*
>  	 * After the host sends the RESCIND_CHANNEL message, it doesn't
> @@ -3599,18 +3670,23 @@ static int hv_pci_bus_exit(struct hv_device
> *hdev, bool keep_devs)
>  	pkt.teardown_packet.compl_ctxt = &comp_pkt;
>  	pkt.teardown_packet.message[0].type = PCI_BUS_D0EXIT;
> 
> +	req_id = alloc_request_id(hbus, &pkt.teardown_packet,
> GFP_KERNEL);
> +	if (req_id < 0)
> +		return req_id;
> +
>  	ret = vmbus_sendpacket(hdev->channel,
> &pkt.teardown_packet.message,
> -			       sizeof(struct pci_message),
> -			       (unsigned long)&pkt.teardown_packet,
> +			       sizeof(struct pci_message), req_id,
>  			       VM_PKT_DATA_INBAND,
> 
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	if (wait_for_completion_timeout(&comp_pkt.host_event, 10 * HZ) ==
> 0)
> -		return -ETIMEDOUT;
> +		ret = -ETIMEDOUT;
> 
> -	return 0;
> +exit:
> +	remove_request_id(hbus, req_id);
> +	return ret;
>  }
> 
>  /**
> @@ -3648,6 +3724,7 @@ static int hv_pci_remove(struct hv_device *hdev)
>  	ret = hv_pci_bus_exit(hdev, false);
> 
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
> 
>  	iounmap(hbus->cfg_addr);
>  	hv_free_config_window(hbus);
> @@ -3704,6 +3781,7 @@ static int hv_pci_suspend(struct hv_device *hdev)
>  		return ret;
> 
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
> 
>  	return 0;
>  }
> @@ -3749,6 +3827,7 @@ static int hv_pci_resume(struct hv_device *hdev)
> 
>  	hbus->state = hv_pcibus_init;
> 
> +	idr_init(&hbus->idr);
>  	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL,
> 0,
>  			 hv_pci_onchannelcallback, hbus);
>  	if (ret)
> @@ -3780,6 +3859,7 @@ static int hv_pci_resume(struct hv_device *hdev)
>  	return 0;
>  out:
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
>  	return ret;
>  }
> 
> --
> 2.25.1
Andrea Parri March 19, 2022, 3:59 p.m. UTC | #2
> > @@ -1208,6 +1211,27 @@ static void hv_pci_read_config_compl(void
> > *context, struct pci_response *resp,
> >  	complete(&comp->comp_pkt.host_event);
> >  }
> > 
> > +static inline int alloc_request_id(struct hv_pcibus_device *hbus,
> > +				   void *ptr, gfp_t gfp)
> > +{
> > +	unsigned long flags;
> > +	int req_id;
> > +
> > +	spin_lock_irqsave(&hbus->idr_lock, flags);
> > +	req_id = idr_alloc(&hbus->idr, ptr, 1, 0, gfp);
> 
> [Saurabh Singh Sengar] Many a place we are using alloc_request_id with GFP_KERNEL, which results this allocation inside of spin lock with GFP_KERNEL.

That's a bug.


> Is this a good opportunity to use idr_preload ?

I'd rather fix it (and 'simplify' the interface a bit) by doing:

static inline int alloc_request_id(struct hv_pcibus_device *hbus, void *ptr)
{
	int id;
	unsigned long irq_flags;

	/*
	 * Map @ptr to a small integer transaction ID.  The IDR is also read
	 * from the channel callback under idr_lock, so take the lock with
	 * IRQs disabled.  The allocation runs with that spinlock held, so it
	 * must not sleep: use GFP_ATOMIC regardless of the caller's context.
	 * IDs start at 1 (presumably so that 0 stays free for packets sent
	 * without a completion request); an end of 0 means "no upper limit".
	 * Returns the new ID, or a negative errno on failure.
	 */
	spin_lock_irqsave(&hbus->idr_lock, irq_flags);
	id = idr_alloc(&hbus->idr, ptr, 1, 0, GFP_ATOMIC);
	spin_unlock_irqrestore(&hbus->idr_lock, irq_flags);

	return id;
}

Thoughts?

Thanks,
  Andrea
Michael Kelley (LINUX) March 19, 2022, 4:20 p.m. UTC | #3
From: Andrea Parri (Microsoft) <parri.andrea@gmail.com> Sent: Friday, March 18, 2022 10:49 AM
> 
> Currently, pointers to guest memory are passed to Hyper-V as transaction
> IDs in hv_pci.  In the face of errors or malicious behavior in Hyper-V,
> hv_pci should not expose or trust the transaction IDs returned by
> Hyper-V to be valid guest memory addresses.  Instead, use small integers
> generated by IDR as request (transaction) IDs.

I had expected that this code would use the next_request_id_callback
mechanism because of the race conditions that mechanism solves.  And
to protect against a malicious Hyper-V sending a bogus second message
with the same requestID, the requestID needs to be freed in the
onchannelcallback function as is done with vmbus_request_addr().
The VMbus message traffic in this driver is a lot lower volume than in
netvsc (for example), but theoretically it seems like the same problems could
occur.  I think my earlier email sketching out a solution over-simplified the
problem and was misleading.

Michael

> 
> Suggested-by: Michael Kelley <mikelley@microsoft.com>
> Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
> ---
>  drivers/pci/controller/pci-hyperv.c | 190 ++++++++++++++++++++--------
>  1 file changed, 135 insertions(+), 55 deletions(-)
> 
> diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-hyperv.c
> index ae0bc2fee4ca8..fbc62aab08fdc 100644
> --- a/drivers/pci/controller/pci-hyperv.c
> +++ b/drivers/pci/controller/pci-hyperv.c
> @@ -495,6 +495,9 @@ struct hv_pcibus_device {
>  	spinlock_t device_list_lock;	/* Protect lists below */
>  	void __iomem *cfg_addr;
> 
> +	spinlock_t idr_lock; /* Serialize accesses to the IDR */
> +	struct idr idr; /* Map guest memory addresses */
> +
>  	struct list_head children;
>  	struct list_head dr_list;
> 
> @@ -1208,6 +1211,27 @@ static void hv_pci_read_config_compl(void *context, struct
> pci_response *resp,
>  	complete(&comp->comp_pkt.host_event);
>  }
> 
> +static inline int alloc_request_id(struct hv_pcibus_device *hbus,
> +				   void *ptr, gfp_t gfp)
> +{
> +	unsigned long flags;
> +	int req_id;
> +
> +	spin_lock_irqsave(&hbus->idr_lock, flags);
> +	req_id = idr_alloc(&hbus->idr, ptr, 1, 0, gfp);
> +	spin_unlock_irqrestore(&hbus->idr_lock, flags);
> +	return req_id;
> +}
> +
> +static inline void remove_request_id(struct hv_pcibus_device *hbus, int req_id)
> +{
> +	unsigned long flags;
> +
> +	spin_lock_irqsave(&hbus->idr_lock, flags);
> +	idr_remove(&hbus->idr, req_id);
> +	spin_unlock_irqrestore(&hbus->idr_lock, flags);
> +}
> +
>  /**
>   * hv_read_config_block() - Sends a read config block request to
>   * the back-end driver running in the Hyper-V parent partition.
> @@ -1232,7 +1256,7 @@ static int hv_read_config_block(struct pci_dev *pdev, void
> *buf,
>  	} pkt;
>  	struct hv_read_config_compl comp_pkt;
>  	struct pci_read_block *read_blk;
> -	int ret;
> +	int req_id, ret;
> 
>  	if (len == 0 || len > HV_CONFIG_BLOCK_SIZE_MAX)
>  		return -EINVAL;
> @@ -1250,16 +1274,19 @@ static int hv_read_config_block(struct pci_dev *pdev, void
> *buf,
>  	read_blk->block_id = block_id;
>  	read_blk->bytes_requested = len;
> 
> +	req_id = alloc_request_id(hbus, &pkt.pkt, GFP_KERNEL);
> +	if (req_id < 0)
> +		return req_id;
> +
>  	ret = vmbus_sendpacket(hbus->hdev->channel, read_blk,
> -			       sizeof(*read_blk), (unsigned long)&pkt.pkt,
> -			       VM_PKT_DATA_INBAND,
> +			       sizeof(*read_blk), req_id, VM_PKT_DATA_INBAND,
>  			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	ret = wait_for_response(hbus->hdev, &comp_pkt.comp_pkt.host_event);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	if (comp_pkt.comp_pkt.completion_status != 0 ||
>  	    comp_pkt.bytes_returned == 0) {
> @@ -1267,11 +1294,14 @@ static int hv_read_config_block(struct pci_dev *pdev, void
> *buf,
>  			"Read Config Block failed: 0x%x, bytes_returned=%d\n",
>  			comp_pkt.comp_pkt.completion_status,
>  			comp_pkt.bytes_returned);
> -		return -EIO;
> +		ret = -EIO;
> +		goto exit;
>  	}
> 
>  	*bytes_returned = comp_pkt.bytes_returned;
> -	return 0;
> +exit:
> +	remove_request_id(hbus, req_id);
> +	return ret;
>  }
> 
>  /**
> @@ -1313,8 +1343,8 @@ static int hv_write_config_block(struct pci_dev *pdev, void
> *buf,
>  	} pkt;
>  	struct hv_pci_compl comp_pkt;
>  	struct pci_write_block *write_blk;
> +	int req_id, ret;
>  	u32 pkt_size;
> -	int ret;
> 
>  	if (len == 0 || len > HV_CONFIG_BLOCK_SIZE_MAX)
>  		return -EINVAL;
> @@ -1340,24 +1370,30 @@ static int hv_write_config_block(struct pci_dev *pdev,
> void *buf,
>  	 */
>  	pkt_size += sizeof(pkt.reserved);
> 
> +	req_id = alloc_request_id(hbus, &pkt.pkt, GFP_KERNEL);
> +	if (req_id < 0)
> +		return req_id;
> +
>  	ret = vmbus_sendpacket(hbus->hdev->channel, write_blk, pkt_size,
> -			       (unsigned long)&pkt.pkt, VM_PKT_DATA_INBAND,
> +			       req_id, VM_PKT_DATA_INBAND,
>  			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	ret = wait_for_response(hbus->hdev, &comp_pkt.host_event);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	if (comp_pkt.completion_status != 0) {
>  		dev_err(&hbus->hdev->device,
>  			"Write Config Block failed: 0x%x\n",
>  			comp_pkt.completion_status);
> -		return -EIO;
> +		ret = -EIO;
>  	}
> 
> -	return 0;
> +exit:
> +	remove_request_id(hbus, req_id);
> +	return ret;
>  }
> 
>  /**
> @@ -1407,7 +1443,7 @@ static void hv_int_desc_free(struct hv_pci_dev *hpdev,
>  	int_pkt->wslot.slot = hpdev->desc.win_slot.slot;
>  	int_pkt->int_desc = *int_desc;
>  	vmbus_sendpacket(hpdev->hbus->hdev->channel, int_pkt, sizeof(*int_pkt),
> -			 (unsigned long)&ctxt.pkt, VM_PKT_DATA_INBAND, 0);
> +			 0, VM_PKT_DATA_INBAND, 0);
>  	kfree(int_desc);
>  }
> 
> @@ -1688,9 +1724,8 @@ static void hv_compose_msi_msg(struct irq_data *data,
> struct msi_msg *msg)
>  			struct pci_create_interrupt3 v3;
>  		} int_pkts;
>  	} __packed ctxt;
> -
> +	int req_id, ret;
>  	u32 size;
> -	int ret;
> 
>  	pdev = msi_desc_to_pci_dev(irq_data_get_msi_desc(data));
>  	dest = irq_data_get_effective_affinity_mask(data);
> @@ -1750,15 +1785,18 @@ static void hv_compose_msi_msg(struct irq_data *data,
> struct msi_msg *msg)
>  		goto free_int_desc;
>  	}
> 
> +	req_id = alloc_request_id(hbus, &ctxt.pci_pkt, GFP_ATOMIC);
> +	if (req_id < 0)
> +		goto free_int_desc;
> +
>  	ret = vmbus_sendpacket(hpdev->hbus->hdev->channel, &ctxt.int_pkts,
> -			       size, (unsigned long)&ctxt.pci_pkt,
> -			       VM_PKT_DATA_INBAND,
> +			       size, req_id, VM_PKT_DATA_INBAND,
>  			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret) {
>  		dev_err(&hbus->hdev->device,
>  			"Sending request for interrupt failed: 0x%x",
>  			comp.comp_pkt.completion_status);
> -		goto free_int_desc;
> +		goto remove_id;
>  	}
> 
>  	/*
> @@ -1811,7 +1849,7 @@ static void hv_compose_msi_msg(struct irq_data *data,
> struct msi_msg *msg)
>  		dev_err(&hbus->hdev->device,
>  			"Request for interrupt failed: 0x%x",
>  			comp.comp_pkt.completion_status);
> -		goto free_int_desc;
> +		goto remove_id;
>  	}
> 
>  	/*
> @@ -1827,11 +1865,14 @@ static void hv_compose_msi_msg(struct irq_data *data,
> struct msi_msg *msg)
>  	msg->address_lo = comp.int_desc.address & 0xffffffff;
>  	msg->data = comp.int_desc.data;
> 
> +	remove_request_id(hbus, req_id);
>  	put_pcichild(hpdev);
>  	return;
> 
>  enable_tasklet:
>  	tasklet_enable(&channel->callback_event);
> +remove_id:
> +	remove_request_id(hbus, req_id);
>  free_int_desc:
>  	kfree(int_desc);
>  drop_reference:
> @@ -2258,7 +2299,7 @@ static struct hv_pci_dev *new_pcichild_device(struct
> hv_pcibus_device *hbus,
>  		u8 buffer[sizeof(struct pci_child_message)];
>  	} pkt;
>  	unsigned long flags;
> -	int ret;
> +	int req_id, ret;
> 
>  	hpdev = kzalloc(sizeof(*hpdev), GFP_KERNEL);
>  	if (!hpdev)
> @@ -2275,16 +2316,19 @@ static struct hv_pci_dev *new_pcichild_device(struct
> hv_pcibus_device *hbus,
>  	res_req->message_type.type = PCI_QUERY_RESOURCE_REQUIREMENTS;
>  	res_req->wslot.slot = desc->win_slot.slot;
> 
> +	req_id = alloc_request_id(hbus, &pkt.init_packet, GFP_KERNEL);
> +	if (req_id < 0)
> +		goto error;
> +
>  	ret = vmbus_sendpacket(hbus->hdev->channel, res_req,
> -			       sizeof(struct pci_child_message),
> -			       (unsigned long)&pkt.init_packet,
> +			       sizeof(struct pci_child_message), req_id,
>  			       VM_PKT_DATA_INBAND,
>  			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		goto error;
> +		goto remove_id;
> 
>  	if (wait_for_response(hbus->hdev, &comp_pkt.host_event))
> -		goto error;
> +		goto remove_id;
> 
>  	hpdev->desc = *desc;
>  	refcount_set(&hpdev->refs, 1);
> @@ -2293,8 +2337,11 @@ static struct hv_pci_dev *new_pcichild_device(struct
> hv_pcibus_device *hbus,
> 
>  	list_add_tail(&hpdev->list_entry, &hbus->children);
>  	spin_unlock_irqrestore(&hbus->device_list_lock, flags);
> +	remove_request_id(hbus, req_id);
>  	return hpdev;
> 
> +remove_id:
> +	remove_request_id(hbus, req_id);
>  error:
>  	kfree(hpdev);
>  	return NULL;
> @@ -2648,8 +2695,7 @@ static void hv_eject_device_work(struct work_struct *work)
>  	ejct_pkt = (struct pci_eject_response *)&ctxt.pkt.message;
>  	ejct_pkt->message_type.type = PCI_EJECTION_COMPLETE;
>  	ejct_pkt->wslot.slot = hpdev->desc.win_slot.slot;
> -	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt,
> -			 sizeof(*ejct_pkt), (unsigned long)&ctxt.pkt,
> +	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt, sizeof(*ejct_pkt), 0,
>  			 VM_PKT_DATA_INBAND, 0);
> 
>  	/* For the get_pcichild() in hv_pci_eject_device() */
> @@ -2709,6 +2755,7 @@ static void hv_pci_onchannelcallback(void *context)
>  	struct pci_dev_inval_block *inval;
>  	struct pci_dev_incoming *dev_message;
>  	struct hv_pci_dev *hpdev;
> +	unsigned long flags;
> 
>  	buffer = kmalloc(bufferlen, GFP_ATOMIC);
>  	if (!buffer)
> @@ -2743,11 +2790,19 @@ static void hv_pci_onchannelcallback(void *context)
>  		switch (desc->type) {
>  		case VM_PKT_COMP:
> 
> -			/*
> -			 * The host is trusted, and thus it's safe to interpret
> -			 * this transaction ID as a pointer.
> -			 */
> -			comp_packet = (struct pci_packet *)req_id;
> +			if (req_id > INT_MAX) {
> +				dev_err_ratelimited(&hbus->hdev->device,
> +						    "Request ID > INT_MAX\n");
> +				break;
> +			}
> +			spin_lock_irqsave(&hbus->idr_lock, flags);
> +			comp_packet = (struct pci_packet *)idr_find(&hbus->idr,
> req_id);
> +			spin_unlock_irqrestore(&hbus->idr_lock, flags);
> +			if (!comp_packet) {
> +				dev_warn_ratelimited(&hbus->hdev->device,
> +						     "Request ID not found\n");
> +				break;
> +			}
>  			response = (struct pci_response *)buffer;
>  			comp_packet->completion_func(comp_packet->compl_ctxt,
>  						     response,
> @@ -2858,8 +2913,7 @@ static int hv_pci_protocol_negotiation(struct hv_device
> *hdev,
>  	struct pci_version_request *version_req;
>  	struct hv_pci_compl comp_pkt;
>  	struct pci_packet *pkt;
> -	int ret;
> -	int i;
> +	int req_id, ret, i;
> 
>  	/*
>  	 * Initiate the handshake with the host and negotiate
> @@ -2877,12 +2931,18 @@ static int hv_pci_protocol_negotiation(struct hv_device
> *hdev,
>  	version_req = (struct pci_version_request *)&pkt->message;
>  	version_req->message_type.type = PCI_QUERY_PROTOCOL_VERSION;
> 
> +	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
> +	if (req_id < 0) {
> +		kfree(pkt);
> +		return req_id;
> +	}
> +
>  	for (i = 0; i < num_version; i++) {
>  		version_req->protocol_version = version[i];
>  		ret = vmbus_sendpacket(hdev->channel, version_req,
> -				sizeof(struct pci_version_request),
> -				(unsigned long)pkt, VM_PKT_DATA_INBAND,
> -
> 	VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
> +				       sizeof(struct pci_version_request),
> +				       req_id, VM_PKT_DATA_INBAND,
> +
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  		if (!ret)
>  			ret = wait_for_response(hdev, &comp_pkt.host_event);
> 
> @@ -2917,6 +2977,7 @@ static int hv_pci_protocol_negotiation(struct hv_device
> *hdev,
>  	ret = -EPROTO;
> 
>  exit:
> +	remove_request_id(hbus, req_id);
>  	kfree(pkt);
>  	return ret;
>  }
> @@ -3079,7 +3140,7 @@ static int hv_pci_enter_d0(struct hv_device *hdev)
>  	struct pci_bus_d0_entry *d0_entry;
>  	struct hv_pci_compl comp_pkt;
>  	struct pci_packet *pkt;
> -	int ret;
> +	int req_id, ret;
> 
>  	/*
>  	 * Tell the host that the bus is ready to use, and moved into the
> @@ -3098,8 +3159,14 @@ static int hv_pci_enter_d0(struct hv_device *hdev)
>  	d0_entry->message_type.type = PCI_BUS_D0ENTRY;
>  	d0_entry->mmio_base = hbus->mem_config->start;
> 
> +	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
> +	if (req_id < 0) {
> +		kfree(pkt);
> +		return req_id;
> +	}
> +
>  	ret = vmbus_sendpacket(hdev->channel, d0_entry, sizeof(*d0_entry),
> -			       (unsigned long)pkt, VM_PKT_DATA_INBAND,
> +			       req_id, VM_PKT_DATA_INBAND,
>  			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (!ret)
>  		ret = wait_for_response(hdev, &comp_pkt.host_event);
> @@ -3112,12 +3179,10 @@ static int hv_pci_enter_d0(struct hv_device *hdev)
>  			"PCI Pass-through VSP failed D0 Entry with status %x\n",
>  			comp_pkt.completion_status);
>  		ret = -EPROTO;
> -		goto exit;
>  	}
> 
> -	ret = 0;
> -
>  exit:
> +	remove_request_id(hbus, req_id);
>  	kfree(pkt);
>  	return ret;
>  }
> @@ -3175,11 +3240,10 @@ static int hv_send_resources_allocated(struct hv_device
> *hdev)
>  	struct pci_resources_assigned *res_assigned;
>  	struct pci_resources_assigned2 *res_assigned2;
>  	struct hv_pci_compl comp_pkt;
> +	int wslot, req_id, ret = 0;
>  	struct hv_pci_dev *hpdev;
>  	struct pci_packet *pkt;
>  	size_t size_res;
> -	int wslot;
> -	int ret;
> 
>  	size_res = (hbus->protocol_version < PCI_PROTOCOL_VERSION_1_2)
>  			? sizeof(*res_assigned) : sizeof(*res_assigned2);
> @@ -3188,7 +3252,11 @@ static int hv_send_resources_allocated(struct hv_device
> *hdev)
>  	if (!pkt)
>  		return -ENOMEM;
> 
> -	ret = 0;
> +	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
> +	if (req_id < 0) {
> +		kfree(pkt);
> +		return req_id;
> +	}
> 
>  	for (wslot = 0; wslot < 256; wslot++) {
>  		hpdev = get_pcichild_wslot(hbus, wslot);
> @@ -3215,10 +3283,9 @@ static int hv_send_resources_allocated(struct hv_device
> *hdev)
>  		}
>  		put_pcichild(hpdev);
> 
> -		ret = vmbus_sendpacket(hdev->channel, &pkt->message,
> -				size_res, (unsigned long)pkt,
> -				VM_PKT_DATA_INBAND,
> -
> 	VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
> +		ret = vmbus_sendpacket(hdev->channel, &pkt->message, size_res,
> +				       req_id, VM_PKT_DATA_INBAND,
> +
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  		if (!ret)
>  			ret = wait_for_response(hdev, &comp_pkt.host_event);
>  		if (ret)
> @@ -3235,6 +3302,7 @@ static int hv_send_resources_allocated(struct hv_device
> *hdev)
>  		hbus->wslot_res_allocated = wslot;
>  	}
> 
> +	remove_request_id(hbus, req_id);
>  	kfree(pkt);
>  	return ret;
>  }
> @@ -3412,6 +3480,7 @@ static int hv_pci_probe(struct hv_device *hdev,
>  	spin_lock_init(&hbus->config_lock);
>  	spin_lock_init(&hbus->device_list_lock);
>  	spin_lock_init(&hbus->retarget_msi_interrupt_lock);
> +	spin_lock_init(&hbus->idr_lock);
>  	hbus->wq = alloc_ordered_workqueue("hv_pci_%x", 0,
>  					   hbus->bridge->domain_nr);
>  	if (!hbus->wq) {
> @@ -3419,6 +3488,7 @@ static int hv_pci_probe(struct hv_device *hdev,
>  		goto free_dom;
>  	}
> 
> +	idr_init(&hbus->idr);
>  	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
>  			 hv_pci_onchannelcallback, hbus);
>  	if (ret)
> @@ -3537,6 +3607,7 @@ static int hv_pci_probe(struct hv_device *hdev,
>  	hv_free_config_window(hbus);
>  close:
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
>  destroy_wq:
>  	destroy_workqueue(hbus->wq);
>  free_dom:
> @@ -3556,7 +3627,7 @@ static int hv_pci_bus_exit(struct hv_device *hdev, bool
> keep_devs)
>  	struct hv_pci_compl comp_pkt;
>  	struct hv_pci_dev *hpdev, *tmp;
>  	unsigned long flags;
> -	int ret;
> +	int req_id, ret;
> 
>  	/*
>  	 * After the host sends the RESCIND_CHANNEL message, it doesn't
> @@ -3599,18 +3670,23 @@ static int hv_pci_bus_exit(struct hv_device *hdev, bool
> keep_devs)
>  	pkt.teardown_packet.compl_ctxt = &comp_pkt;
>  	pkt.teardown_packet.message[0].type = PCI_BUS_D0EXIT;
> 
> +	req_id = alloc_request_id(hbus, &pkt.teardown_packet, GFP_KERNEL);
> +	if (req_id < 0)
> +		return req_id;
> +
>  	ret = vmbus_sendpacket(hdev->channel, &pkt.teardown_packet.message,
> -			       sizeof(struct pci_message),
> -			       (unsigned long)&pkt.teardown_packet,
> +			       sizeof(struct pci_message), req_id,
>  			       VM_PKT_DATA_INBAND,
>  			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
>  	if (ret)
> -		return ret;
> +		goto exit;
> 
>  	if (wait_for_completion_timeout(&comp_pkt.host_event, 10 * HZ) == 0)
> -		return -ETIMEDOUT;
> +		ret = -ETIMEDOUT;
> 
> -	return 0;
> +exit:
> +	remove_request_id(hbus, req_id);
> +	return ret;
>  }
> 
>  /**
> @@ -3648,6 +3724,7 @@ static int hv_pci_remove(struct hv_device *hdev)
>  	ret = hv_pci_bus_exit(hdev, false);
> 
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
> 
>  	iounmap(hbus->cfg_addr);
>  	hv_free_config_window(hbus);
> @@ -3704,6 +3781,7 @@ static int hv_pci_suspend(struct hv_device *hdev)
>  		return ret;
> 
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
> 
>  	return 0;
>  }
> @@ -3749,6 +3827,7 @@ static int hv_pci_resume(struct hv_device *hdev)
> 
>  	hbus->state = hv_pcibus_init;
> 
> +	idr_init(&hbus->idr);
>  	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
>  			 hv_pci_onchannelcallback, hbus);
>  	if (ret)
> @@ -3780,6 +3859,7 @@ static int hv_pci_resume(struct hv_device *hdev)
>  	return 0;
>  out:
>  	vmbus_close(hdev->channel);
> +	idr_destroy(&hbus->idr);
>  	return ret;
>  }
> 
> --
> 2.25.1
Saurabh Singh Sengar March 20, 2022, 5:53 a.m. UTC | #4
> -----Original Message-----
> From: Andrea Parri <parri.andrea@gmail.com>
> Sent: 19 March 2022 21:29
> To: Saurabh Singh Sengar <ssengar@microsoft.com>
> Cc: KY Srinivasan <kys@microsoft.com>; Haiyang Zhang
> <haiyangz@microsoft.com>; Stephen Hemminger
> <sthemmin@microsoft.com>; Wei Liu <wei.liu@kernel.org>; Dexuan Cui
> <decui@microsoft.com>; Michael Kelley (LINUX) <mikelley@microsoft.com>;
> Wei Hu <weh@microsoft.com>; Lorenzo Pieralisi
> <lorenzo.pieralisi@arm.com>; Rob Herring <robh@kernel.org>; Krzysztof
> Wilczynski <kw@linux.com>; Bjorn Helgaas <bhelgaas@google.com>; linux-
> pci@vger.kernel.org; linux-hyperv@vger.kernel.org; linux-
> kernel@vger.kernel.org
> Subject: Re: [EXTERNAL] [PATCH 1/2] PCI: hv: Use IDR to generate transaction
> IDs for VMBus hardening
> 
> > > @@ -1208,6 +1211,27 @@ static void hv_pci_read_config_compl(void
> > > *context, struct pci_response *resp,
> > >  	complete(&comp->comp_pkt.host_event);
> > >  }
> > >
> > > +static inline int alloc_request_id(struct hv_pcibus_device *hbus,
> > > +				   void *ptr, gfp_t gfp)
> > > +{
> > > +	unsigned long flags;
> > > +	int req_id;
> > > +
> > > +	spin_lock_irqsave(&hbus->idr_lock, flags);
> > > +	req_id = idr_alloc(&hbus->idr, ptr, 1, 0, gfp);
> >
> > [Saurabh Singh Sengar] Many a place we are using alloc_request_id with
> GFP_KERNEL, which results this allocation inside of spin lock with
> GFP_KERNEL.
> 
> That's a bug.
> 
> 
> > Is this a good opportunity to use idr_preload ?
> 
> I'd rather fix it (and 'simplify' the interface a bit) by doing:
> 
> static inline int alloc_request_id(struct hv_pcibus_device *hbus, void *ptr)
> {
> 	unsigned long flags;
> 	int req_id;
> 
> 	spin_lock_irqsave(&hbus->idr_lock, flags);
> 	req_id = idr_alloc(&hbus->idr, ptr, 1, 0, GFP_ATOMIC);
> 	spin_unlock_irqrestore(&hbus->idr_lock, flags);
> 	return req_id;
> }
> 
> Thoughts?
[Saurabh Sengar] Yes, if we are fine to use GFP_ATOMIC, this makes perfect sense.
Once fixed, please add: Reviewed-by: Saurabh Sengar <ssengar@microsoft.com>

> 
> Thanks,
>   Andrea
Andrea Parri March 20, 2022, 2:58 p.m. UTC | #5
On Sat, Mar 19, 2022 at 04:20:13PM +0000, Michael Kelley (LINUX) wrote:
> From: Andrea Parri (Microsoft) <parri.andrea@gmail.com> Sent: Friday, March 18, 2022 10:49 AM
> > 
> > Currently, pointers to guest memory are passed to Hyper-V as transaction
> > IDs in hv_pci.  In the face of errors or malicious behavior in Hyper-V,
> > hv_pci should not expose or trust the transaction IDs returned by
> > Hyper-V to be valid guest memory addresses.  Instead, use small integers
> > generated by IDR as request (transaction) IDs.
> 
> I had expected that this code would use the next_request_id_callback
> mechanism because of the race conditions that mechanism solves.  And
> to protect against a malicious Hyper-V sending a bogus second message
> with the same requestID, the requestID needs to be freed in the
> onchannelcallback function as is done with vmbus_request_addr().

I think I should elaborate on the design underlying this submission;
roughly, the present solution diverges from the 'generic' requestor
mechanism you mentioned above in two main aspects:

  A) it 'moves' the ID removal into hv_compose_msi_msg() and other
     functions,

  B) it adopts some ad-hoc locking scheme in the channel callback.

AFAICT, such changes preserve the 'confidentiality' and correctness
guarantees of the generic approach (modulo the issue discussed here
with Saurabh).

These changes are justified by the bug/fix discussed in 2/2.  For
concreteness, consider a solution based on the VMbus requestor as
reported at the end of this email.

AFAICT, this solution can't fix the bug discussed in 2/2.  Moreover
(and looking back at (A-B)), we observe that:

  1) locking in the channel callback is not quite as desired: we'd
     want, say, a request_addr_callback_nolock() and to 'protect' it
     together with ->completion_func();

  2) hv_compose_msi_msg() doesn't know the value of the request ID
     it has allocated (hv_compose_msi_msg() -> vmbus_sendpacket();
     cf. also remove_request_id() in the current submission).

Hope this helps clarify the problems at stake and move forward to a
'final' solution...

Thanks,
  Andrea


diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-hyperv.c
index ae0bc2fee4ca8..bd99dd12d367b 100644
--- a/drivers/pci/controller/pci-hyperv.c
+++ b/drivers/pci/controller/pci-hyperv.c
@@ -91,6 +91,9 @@ static enum pci_protocol_version_t pci_protocol_versions[] = {
 /* space for 32bit serial number as string */
 #define SLOT_NAME_SIZE 11
 
+/* Size of requestor for VMbus */
+#define HV_PCI_RQSTOR_SIZE 64
+
 /*
  * Message Types
  */
@@ -1407,7 +1410,7 @@ static void hv_int_desc_free(struct hv_pci_dev *hpdev,
 	int_pkt->wslot.slot = hpdev->desc.win_slot.slot;
 	int_pkt->int_desc = *int_desc;
 	vmbus_sendpacket(hpdev->hbus->hdev->channel, int_pkt, sizeof(*int_pkt),
-			 (unsigned long)&ctxt.pkt, VM_PKT_DATA_INBAND, 0);
+			 0, VM_PKT_DATA_INBAND, 0);
 	kfree(int_desc);
 }
 
@@ -2649,7 +2652,7 @@ static void hv_eject_device_work(struct work_struct *work)
 	ejct_pkt->message_type.type = PCI_EJECTION_COMPLETE;
 	ejct_pkt->wslot.slot = hpdev->desc.win_slot.slot;
 	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt,
-			 sizeof(*ejct_pkt), (unsigned long)&ctxt.pkt,
+			 sizeof(*ejct_pkt), 0,
 			 VM_PKT_DATA_INBAND, 0);
 
 	/* For the get_pcichild() in hv_pci_eject_device() */
@@ -2696,8 +2699,9 @@ static void hv_pci_onchannelcallback(void *context)
 	const int packet_size = 0x100;
 	int ret;
 	struct hv_pcibus_device *hbus = context;
+	struct vmbus_channel *chan = hbus->hdev->channel;
 	u32 bytes_recvd;
-	u64 req_id;
+	u64 req_id, req_addr;
 	struct vmpacket_descriptor *desc;
 	unsigned char *buffer;
 	int bufferlen = packet_size;
@@ -2743,11 +2747,13 @@ static void hv_pci_onchannelcallback(void *context)
 		switch (desc->type) {
 		case VM_PKT_COMP:
 
-			/*
-			 * The host is trusted, and thus it's safe to interpret
-			 * this transaction ID as a pointer.
-			 */
-			comp_packet = (struct pci_packet *)req_id;
+			req_addr = chan->request_addr_callback(chan, req_id);
+			if (!req_addr || req_addr == VMBUS_RQST_ERROR) {
+				dev_warn_ratelimited(&hbus->hdev->device,
+						     "Invalid request ID\n");
+				break;
+			}
+			comp_packet = (struct pci_packet *)req_addr;
 			response = (struct pci_response *)buffer;
 			comp_packet->completion_func(comp_packet->compl_ctxt,
 						     response,
@@ -3419,6 +3425,10 @@ static int hv_pci_probe(struct hv_device *hdev,
 		goto free_dom;
 	}
 
+	hdev->channel->next_request_id_callback = vmbus_next_request_id;
+	hdev->channel->request_addr_callback = vmbus_request_addr;
+	hdev->channel->rqstor_size = HV_PCI_RQSTOR_SIZE;
+
 	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
 			 hv_pci_onchannelcallback, hbus);
 	if (ret)
@@ -3749,6 +3759,10 @@ static int hv_pci_resume(struct hv_device *hdev)
 
 	hbus->state = hv_pcibus_init;
 
+	hdev->channel->next_request_id_callback = vmbus_next_request_id;
+	hdev->channel->request_addr_callback = vmbus_request_addr;
+	hdev->channel->rqstor_size = HV_PCI_RQSTOR_SIZE;
+
 	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
 			 hv_pci_onchannelcallback, hbus);
 	if (ret)
Michael Kelley (LINUX) March 21, 2022, 6:23 p.m. UTC | #6
From: Andrea Parri <parri.andrea@gmail.com> Sent: Sunday, March 20, 2022 7:59 AM
> 
> On Sat, Mar 19, 2022 at 04:20:13PM +0000, Michael Kelley (LINUX) wrote:
> > From: Andrea Parri (Microsoft) <parri.andrea@gmail.com> Sent: Friday, March 18,
> 2022 10:49 AM
> > >
> > > Currently, pointers to guest memory are passed to Hyper-V as transaction
> > > IDs in hv_pci.  In the face of errors or malicious behavior in Hyper-V,
> > > hv_pci should not expose or trust the transaction IDs returned by
> > > Hyper-V to be valid guest memory addresses.  Instead, use small integers
> > > generated by IDR as request (transaction) IDs.
> >
> > I had expected that this code would use the next_request_id_callback
> > mechanism because of the race conditions that mechanism solves.  And
> > to protect against a malicious Hyper-V sending a bogus second message
> > with the same requestID, the requestID needs to be freed in the
> > onchannelcallback function as is done with vmbus_request_addr().
> 
> I think I should elaborate on the design underlying this submission;
> roughly, the present solution diverges from the 'generic' requestor
> mechanism you mentioned above in two main aspects:
> 
>   A) it 'moves' the ID removal into hv_compose_msi_msg() and other
>      functions,

Right.  A key implication is that this patch allows the completion
function to be called multiple times, if Hyper-V were to be malicious
and send multiple responses with the same requestID.  This is OK as
long as the completion functions are idempotent, which after looking,
I think they are in this driver.

Furthermore, this patch allows the completion function to run anytime
between when the requestID is created and when it is deleted.  This
patch creates the requestID just before calling vmbus_sendpacket(),
which is good.  The requestID is deleted later in the various functions.
I saw only one potential problem, which is in new_pcichild_device(),
where the new hpdev is added to a global list before the requestID is
deleted. There's a window where the completion function could run
and update the probed_bar[] values asynchronously after the hpdev is
on the global list.  I don't know if this is a problem or not, but it could
be prevented by deleting the requestID a little earlier in the function.

> 
>   B) it adopts some ad-hoc locking scheme in the channel callback.
> 
> AFAICT, such changes preserve the 'confidentiality' and correctness
> guarantees of the generic approach (modulo the issue discussed here
> with Saurabh).

Yes, I agree, assuming the current functionality of the completion
functions.

> 
> These changes are justified by the bug/fix discussed in 2/2.  For
> concreteness, consider a solution based on the VMbus requestor as
> reported at the end of this email.
> 
> AFAICT, this solution can't fix the bug discussed in 2/2.  Moreover
> (and looking back at (A-B)), we observe that:
> 
>   1) locking in the channel callback is not quite as desired: we'd
>      want a request_addr_callback_nolock() say and 'protected' it
>      together with ->completion_func();

I'm not understanding this point.  Could you clarify?

> 
>   2) hv_compose_msi_msg() doesn't know the value of the request ID
>      it has allocated (hv_compose_msi_msg() -> vmbus_sendpacket();
>      cf. also remove_request_id() in the current submission).

Agreed.  This would have to be addressed by adding another version of
vmbus_sendpacket() that returns the request ID.

> 
> Hope this helps clarify the problems at stake, and move forward to a
> 'final' solution...

I think there's a reasonable way for the vmbus_next_request_id()
mechanism to solve the problem in Patch 2/2 (if a new version of
vmbus_sendpacket is added).  To me, that mechanism seems safer
in that it restricts the completion function to running just once
per requestID.  With this patch, we must remember that the
completion functions must remain idempotent.

But I can go either way.  I can give an OK on this solution if that's
the preferred path.  Other input is also welcome ...

Michael

> 
> Thanks,
>   Andrea
> 
> 
> diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-hyperv.c
> index ae0bc2fee4ca8..bd99dd12d367b 100644
> --- a/drivers/pci/controller/pci-hyperv.c
> +++ b/drivers/pci/controller/pci-hyperv.c
> @@ -91,6 +91,9 @@ static enum pci_protocol_version_t pci_protocol_versions[] = {
>  /* space for 32bit serial number as string */
>  #define SLOT_NAME_SIZE 11
> 
> +/* Size of requestor for VMbus */
> +#define HV_PCI_RQSTOR_SIZE 64
> +
>  /*
>   * Message Types
>   */
> @@ -1407,7 +1410,7 @@ static void hv_int_desc_free(struct hv_pci_dev *hpdev,
>  	int_pkt->wslot.slot = hpdev->desc.win_slot.slot;
>  	int_pkt->int_desc = *int_desc;
>  	vmbus_sendpacket(hpdev->hbus->hdev->channel, int_pkt, sizeof(*int_pkt),
> -			 (unsigned long)&ctxt.pkt, VM_PKT_DATA_INBAND, 0);
> +			 0, VM_PKT_DATA_INBAND, 0);
>  	kfree(int_desc);
>  }
> 
> @@ -2649,7 +2652,7 @@ static void hv_eject_device_work(struct work_struct *work)
>  	ejct_pkt->message_type.type = PCI_EJECTION_COMPLETE;
>  	ejct_pkt->wslot.slot = hpdev->desc.win_slot.slot;
>  	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt,
> -			 sizeof(*ejct_pkt), (unsigned long)&ctxt.pkt,
> +			 sizeof(*ejct_pkt), 0,
>  			 VM_PKT_DATA_INBAND, 0);
> 
>  	/* For the get_pcichild() in hv_pci_eject_device() */
> @@ -2696,8 +2699,9 @@ static void hv_pci_onchannelcallback(void *context)
>  	const int packet_size = 0x100;
>  	int ret;
>  	struct hv_pcibus_device *hbus = context;
> +	struct vmbus_channel *chan = hbus->hdev->channel;
>  	u32 bytes_recvd;
> -	u64 req_id;
> +	u64 req_id, req_addr;
>  	struct vmpacket_descriptor *desc;
>  	unsigned char *buffer;
>  	int bufferlen = packet_size;
> @@ -2743,11 +2747,13 @@ static void hv_pci_onchannelcallback(void *context)
>  		switch (desc->type) {
>  		case VM_PKT_COMP:
> 
> -			/*
> -			 * The host is trusted, and thus it's safe to interpret
> -			 * this transaction ID as a pointer.
> -			 */
> -			comp_packet = (struct pci_packet *)req_id;
> +			req_addr = chan->request_addr_callback(chan, req_id);
> +			if (!req_addr || req_addr == VMBUS_RQST_ERROR) {
> +				dev_warn_ratelimited(&hbus->hdev->device,
> +						     "Invalid request ID\n");
> +				break;
> +			}
> +			comp_packet = (struct pci_packet *)req_addr;
>  			response = (struct pci_response *)buffer;
>  			comp_packet->completion_func(comp_packet->compl_ctxt,
>  						     response,
> @@ -3419,6 +3425,10 @@ static int hv_pci_probe(struct hv_device *hdev,
>  		goto free_dom;
>  	}
> 
> +	hdev->channel->next_request_id_callback = vmbus_next_request_id;
> +	hdev->channel->request_addr_callback = vmbus_request_addr;
> +	hdev->channel->rqstor_size = HV_PCI_RQSTOR_SIZE;
> +
>  	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
>  			 hv_pci_onchannelcallback, hbus);
>  	if (ret)
> @@ -3749,6 +3759,10 @@ static int hv_pci_resume(struct hv_device *hdev)
> 
>  	hbus->state = hv_pcibus_init;
> 
> +	hdev->channel->next_request_id_callback = vmbus_next_request_id;
> +	hdev->channel->request_addr_callback = vmbus_request_addr;
> +	hdev->channel->rqstor_size = HV_PCI_RQSTOR_SIZE;
> +
>  	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
>  			 hv_pci_onchannelcallback, hbus);
>  	if (ret)
Andrea Parri March 22, 2022, 12:51 p.m. UTC | #7
> > I think I should elaborate on the design underlying this submission;
> > roughly, the present solution diverges from the 'generic' requestor
> > mechanism you mentioned above in two main aspects:
> > 
> >   A) it 'moves' the ID removal into hv_compose_msi_msg() and other
> >      functions,
> 
> Right.  A key implication is that this patch allows the completion
> function to be called multiple times, if Hyper-V were to be malicious
> and send multiple responses with the same requestID.  This is OK as
> long as the completion functions are idempotent, which after looking,
> I think they are in this driver.
> 
> Furthermore, this patch allows the completion function to run anytime
> between when the requestID is created and when it is deleted.  This
> patch creates the requestID just before calling vmbus_sendpacket(),
> which is good.  The requestID is deleted later in the various functions.
> I saw only one potential problem, which is in new_pcichild_device(),
> where the new hpdev is added to a global list before the requestID is
> deleted. There's a window where the completion function could run
> and update the probed_bar[] values asynchronously after the hpdev is
> on the global list.  I don't know if this is a problem or not, but it could
> be prevented by deleting the requestID a little earlier in the function.
> 
> > 
> >   B) it adopts some ad-hoc locking scheme in the channel callback.
> > 
> > AFAICT, such changes preserve the 'confidentiality' and correctness
> > guarantees of the generic approach (modulo the issue discussed here
> > with Saurabh).
> 
> Yes, I agree, assuming the current functionality of the completion
> functions.
> 
> > 
> > These changes are justified by the bug/fix discussed in 2/2.  For
> > concreteness, consider a solution based on the VMbus requestor as
> > reported at the end of this email.
> > 
> > AFAICT, this solution can't fix the bug discussed in 2/2.  Moreover
> > (and looking back at (A-B)), we observe that:
> > 
> >   1) locking in the channel callback is not quite as desired: we'd
> >      want a request_addr_callback_nolock() say and 'protected' it
> >      together with ->completion_func();
> 
> I'm not understanding this point.  Could you clarify?

Basically (on top of the previous diff):

@@ -2700,6 +2725,7 @@ static void hv_pci_onchannelcallback(void *context)
 	int ret;
 	struct hv_pcibus_device *hbus = context;
 	struct vmbus_channel *chan = hbus->hdev->channel;
+	struct vmbus_requestor *rqstor = &chan->requestor;
 	u32 bytes_recvd;
 	u64 req_id, req_addr;
 	struct vmpacket_descriptor *desc;
@@ -2713,6 +2739,7 @@ static void hv_pci_onchannelcallback(void *context)
 	struct pci_dev_inval_block *inval;
 	struct pci_dev_incoming *dev_message;
 	struct hv_pci_dev *hpdev;
+	unsigned long flags;
 
 	buffer = kmalloc(bufferlen, GFP_ATOMIC);
 	if (!buffer)
@@ -2747,8 +2774,10 @@ static void hv_pci_onchannelcallback(void *context)
 		switch (desc->type) {
 		case VM_PKT_COMP:
 
-			req_addr = chan->request_addr_callback(chan, req_id);
+			spin_lock_irqsave(&rqstor->req_lock, flags);
+			req_addr = __hv_pci_request_addr(chan, req_id);
 			if (!req_addr || req_addr == VMBUS_RQST_ERROR) {
+				spin_unlock_irqrestore(&rqstor->req_lock, flags);
 				dev_warn_ratelimited(&hbus->hdev->device,
 						     "Invalid request ID\n");
 				break;
@@ -2758,6 +2787,7 @@ static void hv_pci_onchannelcallback(void *context)
 			comp_packet->completion_func(comp_packet->compl_ctxt,
 						     response,
 						     bytes_recvd);
+			spin_unlock_irqrestore(&rqstor->req_lock, flags);
 			break;
 
 		case VM_PKT_DATA_INBAND:


where I renamed request_addr_callback_nolock() to __hv_pci_request_addr()
(this is the same as vmbus_request_addr(), but without taking the
requestor lock).  A "locked" callback would still be wanted and used in,
e.g., the failure path of hv_ringbuffer_write().


> >   2) hv_compose_msi_msg() doesn't know the value of the request ID
> >      it has allocated (hv_compose_msi_msg() -> vmbus_sendpacket();
> >      cf. also remove_request_id() in the current submission).
> 
> Agreed.  This would have to be addressed by adding another version of
> vmbus_sendpacket() that returns the request ID.

Indeed...  Some care would be needed to make sure that the "ID removal"
can't "race" with hv_pci_onchannelcallback() (which could already have
removed the ID by then), but yes...


> > Hope this helps clarify the problems at stake, and move forward to a
> > 'final' solution...
> 
> I think there's a reasonable way for the vmbus_next_request_id()
> mechanism to solve the problem in Patch 2/2 (if a new version of
> vmbus_sendpacket is added).  To me, that mechanism seems safer
> in that it restricts the completion function to running just once
> per requestID.  With this patch, we must remember that the
> completion functions must remain idempotent.

Fair enough.  Thank you for bearing with me and patiently reviewing these
matters.  Working out the details...

Thanks,
  Andrea
diff mbox series

Patch

diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-hyperv.c
index ae0bc2fee4ca8..fbc62aab08fdc 100644
--- a/drivers/pci/controller/pci-hyperv.c
+++ b/drivers/pci/controller/pci-hyperv.c
@@ -495,6 +495,9 @@  struct hv_pcibus_device {
 	spinlock_t device_list_lock;	/* Protect lists below */
 	void __iomem *cfg_addr;
 
+	spinlock_t idr_lock; /* Serialize accesses to the IDR */
+	struct idr idr; /* Map guest memory addresses */
+
 	struct list_head children;
 	struct list_head dr_list;
 
@@ -1208,6 +1211,27 @@  static void hv_pci_read_config_compl(void *context, struct pci_response *resp,
 	complete(&comp->comp_pkt.host_event);
 }
 
+static inline int alloc_request_id(struct hv_pcibus_device *hbus,
+				   void *ptr, gfp_t gfp)
+{
+	unsigned long flags;
+	int req_id;
+
+	spin_lock_irqsave(&hbus->idr_lock, flags);
+	req_id = idr_alloc(&hbus->idr, ptr, 1, 0, gfp);
+	spin_unlock_irqrestore(&hbus->idr_lock, flags);
+	return req_id;
+}
+
+static inline void remove_request_id(struct hv_pcibus_device *hbus, int req_id)
+{
+	unsigned long flags;
+
+	spin_lock_irqsave(&hbus->idr_lock, flags);
+	idr_remove(&hbus->idr, req_id);
+	spin_unlock_irqrestore(&hbus->idr_lock, flags);
+}
+
 /**
  * hv_read_config_block() - Sends a read config block request to
  * the back-end driver running in the Hyper-V parent partition.
@@ -1232,7 +1256,7 @@  static int hv_read_config_block(struct pci_dev *pdev, void *buf,
 	} pkt;
 	struct hv_read_config_compl comp_pkt;
 	struct pci_read_block *read_blk;
-	int ret;
+	int req_id, ret;
 
 	if (len == 0 || len > HV_CONFIG_BLOCK_SIZE_MAX)
 		return -EINVAL;
@@ -1250,16 +1274,19 @@  static int hv_read_config_block(struct pci_dev *pdev, void *buf,
 	read_blk->block_id = block_id;
 	read_blk->bytes_requested = len;
 
+	req_id = alloc_request_id(hbus, &pkt.pkt, GFP_KERNEL);
+	if (req_id < 0)
+		return req_id;
+
 	ret = vmbus_sendpacket(hbus->hdev->channel, read_blk,
-			       sizeof(*read_blk), (unsigned long)&pkt.pkt,
-			       VM_PKT_DATA_INBAND,
+			       sizeof(*read_blk), req_id, VM_PKT_DATA_INBAND,
 			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 	if (ret)
-		return ret;
+		goto exit;
 
 	ret = wait_for_response(hbus->hdev, &comp_pkt.comp_pkt.host_event);
 	if (ret)
-		return ret;
+		goto exit;
 
 	if (comp_pkt.comp_pkt.completion_status != 0 ||
 	    comp_pkt.bytes_returned == 0) {
@@ -1267,11 +1294,14 @@  static int hv_read_config_block(struct pci_dev *pdev, void *buf,
 			"Read Config Block failed: 0x%x, bytes_returned=%d\n",
 			comp_pkt.comp_pkt.completion_status,
 			comp_pkt.bytes_returned);
-		return -EIO;
+		ret = -EIO;
+		goto exit;
 	}
 
 	*bytes_returned = comp_pkt.bytes_returned;
-	return 0;
+exit:
+	remove_request_id(hbus, req_id);
+	return ret;
 }
 
 /**
@@ -1313,8 +1343,8 @@  static int hv_write_config_block(struct pci_dev *pdev, void *buf,
 	} pkt;
 	struct hv_pci_compl comp_pkt;
 	struct pci_write_block *write_blk;
+	int req_id, ret;
 	u32 pkt_size;
-	int ret;
 
 	if (len == 0 || len > HV_CONFIG_BLOCK_SIZE_MAX)
 		return -EINVAL;
@@ -1340,24 +1370,30 @@  static int hv_write_config_block(struct pci_dev *pdev, void *buf,
 	 */
 	pkt_size += sizeof(pkt.reserved);
 
+	req_id = alloc_request_id(hbus, &pkt.pkt, GFP_KERNEL);
+	if (req_id < 0)
+		return req_id;
+
 	ret = vmbus_sendpacket(hbus->hdev->channel, write_blk, pkt_size,
-			       (unsigned long)&pkt.pkt, VM_PKT_DATA_INBAND,
+			       req_id, VM_PKT_DATA_INBAND,
 			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 	if (ret)
-		return ret;
+		goto exit;
 
 	ret = wait_for_response(hbus->hdev, &comp_pkt.host_event);
 	if (ret)
-		return ret;
+		goto exit;
 
 	if (comp_pkt.completion_status != 0) {
 		dev_err(&hbus->hdev->device,
 			"Write Config Block failed: 0x%x\n",
 			comp_pkt.completion_status);
-		return -EIO;
+		ret = -EIO;
 	}
 
-	return 0;
+exit:
+	remove_request_id(hbus, req_id);
+	return ret;
 }
 
 /**
@@ -1407,7 +1443,7 @@  static void hv_int_desc_free(struct hv_pci_dev *hpdev,
 	int_pkt->wslot.slot = hpdev->desc.win_slot.slot;
 	int_pkt->int_desc = *int_desc;
 	vmbus_sendpacket(hpdev->hbus->hdev->channel, int_pkt, sizeof(*int_pkt),
-			 (unsigned long)&ctxt.pkt, VM_PKT_DATA_INBAND, 0);
+			 0, VM_PKT_DATA_INBAND, 0);
 	kfree(int_desc);
 }
 
@@ -1688,9 +1724,8 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 			struct pci_create_interrupt3 v3;
 		} int_pkts;
 	} __packed ctxt;
-
+	int req_id, ret;
 	u32 size;
-	int ret;
 
 	pdev = msi_desc_to_pci_dev(irq_data_get_msi_desc(data));
 	dest = irq_data_get_effective_affinity_mask(data);
@@ -1750,15 +1785,18 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 		goto free_int_desc;
 	}
 
+	req_id = alloc_request_id(hbus, &ctxt.pci_pkt, GFP_ATOMIC);
+	if (req_id < 0)
+		goto free_int_desc;
+
 	ret = vmbus_sendpacket(hpdev->hbus->hdev->channel, &ctxt.int_pkts,
-			       size, (unsigned long)&ctxt.pci_pkt,
-			       VM_PKT_DATA_INBAND,
+			       size, req_id, VM_PKT_DATA_INBAND,
 			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 	if (ret) {
 		dev_err(&hbus->hdev->device,
 			"Sending request for interrupt failed: 0x%x",
 			comp.comp_pkt.completion_status);
-		goto free_int_desc;
+		goto remove_id;
 	}
 
 	/*
@@ -1811,7 +1849,7 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 		dev_err(&hbus->hdev->device,
 			"Request for interrupt failed: 0x%x",
 			comp.comp_pkt.completion_status);
-		goto free_int_desc;
+		goto remove_id;
 	}
 
 	/*
@@ -1827,11 +1865,14 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 	msg->address_lo = comp.int_desc.address & 0xffffffff;
 	msg->data = comp.int_desc.data;
 
+	remove_request_id(hbus, req_id);
 	put_pcichild(hpdev);
 	return;
 
 enable_tasklet:
 	tasklet_enable(&channel->callback_event);
+remove_id:
+	remove_request_id(hbus, req_id);
 free_int_desc:
 	kfree(int_desc);
 drop_reference:
@@ -2258,7 +2299,7 @@  static struct hv_pci_dev *new_pcichild_device(struct hv_pcibus_device *hbus,
 		u8 buffer[sizeof(struct pci_child_message)];
 	} pkt;
 	unsigned long flags;
-	int ret;
+	int req_id, ret;
 
 	hpdev = kzalloc(sizeof(*hpdev), GFP_KERNEL);
 	if (!hpdev)
@@ -2275,16 +2316,19 @@  static struct hv_pci_dev *new_pcichild_device(struct hv_pcibus_device *hbus,
 	res_req->message_type.type = PCI_QUERY_RESOURCE_REQUIREMENTS;
 	res_req->wslot.slot = desc->win_slot.slot;
 
+	req_id = alloc_request_id(hbus, &pkt.init_packet, GFP_KERNEL);
+	if (req_id < 0)
+		goto error;
+
 	ret = vmbus_sendpacket(hbus->hdev->channel, res_req,
-			       sizeof(struct pci_child_message),
-			       (unsigned long)&pkt.init_packet,
+			       sizeof(struct pci_child_message), req_id,
 			       VM_PKT_DATA_INBAND,
 			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 	if (ret)
-		goto error;
+		goto remove_id;
 
 	if (wait_for_response(hbus->hdev, &comp_pkt.host_event))
-		goto error;
+		goto remove_id;
 
 	hpdev->desc = *desc;
 	refcount_set(&hpdev->refs, 1);
@@ -2293,8 +2337,11 @@  static struct hv_pci_dev *new_pcichild_device(struct hv_pcibus_device *hbus,
 
 	list_add_tail(&hpdev->list_entry, &hbus->children);
 	spin_unlock_irqrestore(&hbus->device_list_lock, flags);
+	remove_request_id(hbus, req_id);
 	return hpdev;
 
+remove_id:
+	remove_request_id(hbus, req_id);
 error:
 	kfree(hpdev);
 	return NULL;
@@ -2648,8 +2695,7 @@  static void hv_eject_device_work(struct work_struct *work)
 	ejct_pkt = (struct pci_eject_response *)&ctxt.pkt.message;
 	ejct_pkt->message_type.type = PCI_EJECTION_COMPLETE;
 	ejct_pkt->wslot.slot = hpdev->desc.win_slot.slot;
-	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt,
-			 sizeof(*ejct_pkt), (unsigned long)&ctxt.pkt,
+	vmbus_sendpacket(hbus->hdev->channel, ejct_pkt, sizeof(*ejct_pkt), 0,
 			 VM_PKT_DATA_INBAND, 0);
 
 	/* For the get_pcichild() in hv_pci_eject_device() */
@@ -2709,6 +2755,7 @@  static void hv_pci_onchannelcallback(void *context)
 	struct pci_dev_inval_block *inval;
 	struct pci_dev_incoming *dev_message;
 	struct hv_pci_dev *hpdev;
+	unsigned long flags;
 
 	buffer = kmalloc(bufferlen, GFP_ATOMIC);
 	if (!buffer)
@@ -2743,11 +2790,19 @@  static void hv_pci_onchannelcallback(void *context)
 		switch (desc->type) {
 		case VM_PKT_COMP:
 
-			/*
-			 * The host is trusted, and thus it's safe to interpret
-			 * this transaction ID as a pointer.
-			 */
-			comp_packet = (struct pci_packet *)req_id;
+			if (req_id > INT_MAX) {
+				dev_err_ratelimited(&hbus->hdev->device,
+						    "Request ID > INT_MAX\n");
+				break;
+			}
+			spin_lock_irqsave(&hbus->idr_lock, flags);
+			comp_packet = (struct pci_packet *)idr_find(&hbus->idr, req_id);
+			spin_unlock_irqrestore(&hbus->idr_lock, flags);
+			if (!comp_packet) {
+				dev_warn_ratelimited(&hbus->hdev->device,
+						     "Request ID not found\n");
+				break;
+			}
 			response = (struct pci_response *)buffer;
 			comp_packet->completion_func(comp_packet->compl_ctxt,
 						     response,
@@ -2858,8 +2913,7 @@  static int hv_pci_protocol_negotiation(struct hv_device *hdev,
 	struct pci_version_request *version_req;
 	struct hv_pci_compl comp_pkt;
 	struct pci_packet *pkt;
-	int ret;
-	int i;
+	int req_id, ret, i;
 
 	/*
 	 * Initiate the handshake with the host and negotiate
@@ -2877,12 +2931,18 @@  static int hv_pci_protocol_negotiation(struct hv_device *hdev,
 	version_req = (struct pci_version_request *)&pkt->message;
 	version_req->message_type.type = PCI_QUERY_PROTOCOL_VERSION;
 
+	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
+	if (req_id < 0) {
+		kfree(pkt);
+		return req_id;
+	}
+
 	for (i = 0; i < num_version; i++) {
 		version_req->protocol_version = version[i];
 		ret = vmbus_sendpacket(hdev->channel, version_req,
-				sizeof(struct pci_version_request),
-				(unsigned long)pkt, VM_PKT_DATA_INBAND,
-				VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
+				       sizeof(struct pci_version_request),
+				       req_id, VM_PKT_DATA_INBAND,
+				       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 		if (!ret)
 			ret = wait_for_response(hdev, &comp_pkt.host_event);
 
@@ -2917,6 +2977,7 @@  static int hv_pci_protocol_negotiation(struct hv_device *hdev,
 	ret = -EPROTO;
 
 exit:
+	remove_request_id(hbus, req_id);
 	kfree(pkt);
 	return ret;
 }
@@ -3079,7 +3140,7 @@  static int hv_pci_enter_d0(struct hv_device *hdev)
 	struct pci_bus_d0_entry *d0_entry;
 	struct hv_pci_compl comp_pkt;
 	struct pci_packet *pkt;
-	int ret;
+	int req_id, ret;
 
 	/*
 	 * Tell the host that the bus is ready to use, and moved into the
@@ -3098,8 +3159,14 @@  static int hv_pci_enter_d0(struct hv_device *hdev)
 	d0_entry->message_type.type = PCI_BUS_D0ENTRY;
 	d0_entry->mmio_base = hbus->mem_config->start;
 
+	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
+	if (req_id < 0) {
+		kfree(pkt);
+		return req_id;
+	}
+
 	ret = vmbus_sendpacket(hdev->channel, d0_entry, sizeof(*d0_entry),
-			       (unsigned long)pkt, VM_PKT_DATA_INBAND,
+			       req_id, VM_PKT_DATA_INBAND,
 			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 	if (!ret)
 		ret = wait_for_response(hdev, &comp_pkt.host_event);
@@ -3112,12 +3179,10 @@  static int hv_pci_enter_d0(struct hv_device *hdev)
 			"PCI Pass-through VSP failed D0 Entry with status %x\n",
 			comp_pkt.completion_status);
 		ret = -EPROTO;
-		goto exit;
 	}
 
-	ret = 0;
-
 exit:
+	remove_request_id(hbus, req_id);
 	kfree(pkt);
 	return ret;
 }
@@ -3175,11 +3240,10 @@  static int hv_send_resources_allocated(struct hv_device *hdev)
 	struct pci_resources_assigned *res_assigned;
 	struct pci_resources_assigned2 *res_assigned2;
 	struct hv_pci_compl comp_pkt;
+	int wslot, req_id, ret = 0;
 	struct hv_pci_dev *hpdev;
 	struct pci_packet *pkt;
 	size_t size_res;
-	int wslot;
-	int ret;
 
 	size_res = (hbus->protocol_version < PCI_PROTOCOL_VERSION_1_2)
 			? sizeof(*res_assigned) : sizeof(*res_assigned2);
@@ -3188,7 +3252,11 @@  static int hv_send_resources_allocated(struct hv_device *hdev)
 	if (!pkt)
 		return -ENOMEM;
 
-	ret = 0;
+	req_id = alloc_request_id(hbus, pkt, GFP_KERNEL);
+	if (req_id < 0) {
+		kfree(pkt);
+		return req_id;
+	}
 
 	for (wslot = 0; wslot < 256; wslot++) {
 		hpdev = get_pcichild_wslot(hbus, wslot);
@@ -3215,10 +3283,9 @@  static int hv_send_resources_allocated(struct hv_device *hdev)
 		}
 		put_pcichild(hpdev);
 
-		ret = vmbus_sendpacket(hdev->channel, &pkt->message,
-				size_res, (unsigned long)pkt,
-				VM_PKT_DATA_INBAND,
-				VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
+		ret = vmbus_sendpacket(hdev->channel, &pkt->message, size_res,
+				       req_id, VM_PKT_DATA_INBAND,
+				       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 		if (!ret)
 			ret = wait_for_response(hdev, &comp_pkt.host_event);
 		if (ret)
@@ -3235,6 +3302,7 @@  static int hv_send_resources_allocated(struct hv_device *hdev)
 		hbus->wslot_res_allocated = wslot;
 	}
 
+	remove_request_id(hbus, req_id);
 	kfree(pkt);
 	return ret;
 }
@@ -3412,6 +3480,7 @@  static int hv_pci_probe(struct hv_device *hdev,
 	spin_lock_init(&hbus->config_lock);
 	spin_lock_init(&hbus->device_list_lock);
 	spin_lock_init(&hbus->retarget_msi_interrupt_lock);
+	spin_lock_init(&hbus->idr_lock);
 	hbus->wq = alloc_ordered_workqueue("hv_pci_%x", 0,
 					   hbus->bridge->domain_nr);
 	if (!hbus->wq) {
@@ -3419,6 +3488,7 @@  static int hv_pci_probe(struct hv_device *hdev,
 		goto free_dom;
 	}
 
+	idr_init(&hbus->idr);
 	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
 			 hv_pci_onchannelcallback, hbus);
 	if (ret)
@@ -3537,6 +3607,7 @@  static int hv_pci_probe(struct hv_device *hdev,
 	hv_free_config_window(hbus);
 close:
 	vmbus_close(hdev->channel);
+	idr_destroy(&hbus->idr);
 destroy_wq:
 	destroy_workqueue(hbus->wq);
 free_dom:
@@ -3556,7 +3627,7 @@  static int hv_pci_bus_exit(struct hv_device *hdev, bool keep_devs)
 	struct hv_pci_compl comp_pkt;
 	struct hv_pci_dev *hpdev, *tmp;
 	unsigned long flags;
-	int ret;
+	int req_id, ret;
 
 	/*
 	 * After the host sends the RESCIND_CHANNEL message, it doesn't
@@ -3599,18 +3670,23 @@  static int hv_pci_bus_exit(struct hv_device *hdev, bool keep_devs)
 	pkt.teardown_packet.compl_ctxt = &comp_pkt;
 	pkt.teardown_packet.message[0].type = PCI_BUS_D0EXIT;
 
+	req_id = alloc_request_id(hbus, &pkt.teardown_packet, GFP_KERNEL);
+	if (req_id < 0)
+		return req_id;
+
 	ret = vmbus_sendpacket(hdev->channel, &pkt.teardown_packet.message,
-			       sizeof(struct pci_message),
-			       (unsigned long)&pkt.teardown_packet,
+			       sizeof(struct pci_message), req_id,
 			       VM_PKT_DATA_INBAND,
 			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 	if (ret)
-		return ret;
+		goto exit;
 
 	if (wait_for_completion_timeout(&comp_pkt.host_event, 10 * HZ) == 0)
-		return -ETIMEDOUT;
+		ret = -ETIMEDOUT;
 
-	return 0;
+exit:
+	remove_request_id(hbus, req_id);
+	return ret;
 }
 
 /**
@@ -3648,6 +3724,7 @@  static int hv_pci_remove(struct hv_device *hdev)
 	ret = hv_pci_bus_exit(hdev, false);
 
 	vmbus_close(hdev->channel);
+	idr_destroy(&hbus->idr);
 
 	iounmap(hbus->cfg_addr);
 	hv_free_config_window(hbus);
@@ -3704,6 +3781,7 @@  static int hv_pci_suspend(struct hv_device *hdev)
 		return ret;
 
 	vmbus_close(hdev->channel);
+	idr_destroy(&hbus->idr);
 
 	return 0;
 }
@@ -3749,6 +3827,7 @@  static int hv_pci_resume(struct hv_device *hdev)
 
 	hbus->state = hv_pcibus_init;
 
+	idr_init(&hbus->idr);
 	ret = vmbus_open(hdev->channel, pci_ring_size, pci_ring_size, NULL, 0,
 			 hv_pci_onchannelcallback, hbus);
 	if (ret)
@@ -3780,6 +3859,7 @@  static int hv_pci_resume(struct hv_device *hdev)
 	return 0;
 out:
 	vmbus_close(hdev->channel);
+	idr_destroy(&hbus->idr);
 	return ret;
 }