diff mbox series

[1/2] virtio_blk: add length check for device writable portion

Message ID 20250224233106.8519-2-mgurtovoy@nvidia.com (mailing list archive)
State New
Headers show
Series virtio: Add length checks for device writable portions | expand

Commit Message

Max Gurtovoy Feb. 24, 2025, 11:31 p.m. UTC
Add a safety check to ensure that the length of data written by the
device is at least as large the expected length. If this condition is
not met, it indicates a potential error in the device's response.

This change aligns with the virtio specification, which states:
"The driver MUST NOT make assumptions about data in device-writable
buffers beyond the first len bytes, and SHOULD ignore this data."

By setting an error status when len is insufficient, we ensure that the
driver does not process potentially invalid or incomplete data from the
device.

Reviewed-by: Aurelien Aptel <aaptel@nvidia.com>
Signed-off-by: Lokesh Arora <larora@nvidia.com>
Signed-off-by: Israel Rukshin <israelr@nvidia.com>
Signed-off-by: Max Gurtovoy <mgurtovoy@nvidia.com>
---
 drivers/block/virtio_blk.c | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)
diff mbox series

Patch

diff --git a/drivers/block/virtio_blk.c b/drivers/block/virtio_blk.c
index 6a61ec35f426..58407cfee3ee 100644
--- a/drivers/block/virtio_blk.c
+++ b/drivers/block/virtio_blk.c
@@ -331,6 +331,20 @@  static inline u8 virtblk_vbr_status(struct virtblk_req *vbr)
 	return *((u8 *)&vbr->in_hdr + vbr->in_hdr_len - 1);
 }
 
+static inline void virtblk_vbr_set_err_status_upon_len_err(struct virtblk_req *vbr,
+		struct request *req, unsigned int len)
+{
+	unsigned int expected_len = vbr->in_hdr_len;
+
+	if (rq_dma_dir(req) == DMA_FROM_DEVICE)
+		expected_len += blk_rq_payload_bytes(req);
+
+	if (unlikely(len < expected_len)) {
+		u8 *status_ptr = (u8 *)&vbr->in_hdr + vbr->in_hdr_len - 1;
+		*status_ptr = VIRTIO_BLK_S_IOERR;
+	}
+}
+
 static inline void virtblk_request_done(struct request *req)
 {
 	struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
@@ -362,6 +376,9 @@  static void virtblk_done(struct virtqueue *vq)
 		while ((vbr = virtqueue_get_buf(vblk->vqs[qid].vq, &len)) != NULL) {
 			struct request *req = blk_mq_rq_from_pdu(vbr);
 
+			/* Check device writable portion length, and fail upon error */
+			virtblk_vbr_set_err_status_upon_len_err(vbr, req, len);
+
 			if (likely(!blk_should_fake_timeout(req->q)))
 				blk_mq_complete_request(req);
 			req_done = true;
@@ -1208,6 +1225,9 @@  static int virtblk_poll(struct blk_mq_hw_ctx *hctx, struct io_comp_batch *iob)
 	while ((vbr = virtqueue_get_buf(vq->vq, &len)) != NULL) {
 		struct request *req = blk_mq_rq_from_pdu(vbr);
 
+		/* Check device writable portion length, and fail upon error */
+		virtblk_vbr_set_err_status_upon_len_err(vbr, req, len);
+
 		found++;
 		if (!blk_mq_complete_request_remote(req) &&
 		    !blk_mq_add_to_batch(req, iob, virtblk_vbr_status(vbr),