@@ -429,6 +429,14 @@ static int vq_access_ok(unsigned int num,
struct vring_avail __user *avail,
struct vring_used __user *used)
{
+
+ if (num > UINT_MAX / sizeof *desc)
+ return 0;
+ if (num > UINT_MAX / sizeof *avail->ring - sizeof *avail)
+ return 0;
+ if (num > UINT_MAX / sizeof *used->ring - sizeof *used)
+ return 0;
+
return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
access_ok(VERIFY_READ, avail,
sizeof *avail + num * sizeof *avail->ring) &&
@@ -447,6 +455,9 @@ int vhost_log_access_ok(struct vhost_dev *dev)
/* Caller should have vq mutex and device mutex */
static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base)
{
+ if (vq->num > UINT_MAX / sizeof *vq->used->ring - sizeof *vq->used)
+ return 0;
+
return vq_memory_access_ok(log_base, vq->dev->memory,
vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
@@ -606,12 +617,17 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
}
/* Also validate log access for used ring if enabled. */
- if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
- !log_access_ok(vq->log_base, a.log_guest_addr,
+ if (a.flags & (0x1 << VHOST_VRING_F_LOG)) {
+ if (vq->num > UINT_MAX / sizeof *vq->used->ring - sizeof *vq->used) {
+ r = -EINVAL;
+ break;
+ }
+ if (!log_access_ok(vq->log_base, a.log_guest_addr,
sizeof *vq->used +
vq->num * sizeof *vq->used->ring)) {
- r = -EINVAL;
- break;
+ r = -EINVAL;
+ break;
+ }
}
}