@@ -597,6 +597,8 @@ struct iommu_device {
/**
* struct iommu_fault_param - per-device IOMMU fault data
* @lock: protect pending faults list
+ * @users: reference count that manages the lifetime of this data;
+ * protected by dev->iommu->lock.
* @dev: the device that owns this param
* @queue: IOPF queue
* @queue_list: index into queue->devices
@@ -606,6 +608,7 @@ struct iommu_device {
*/
struct iommu_fault_param {
struct mutex lock;
+ int users; /* protected by dev->iommu->lock */
struct device *dev;
struct iopf_queue *queue;
@@ -26,6 +26,49 @@ void iopf_free_group(struct iopf_group *group)
}
EXPORT_SYMBOL_GPL(iopf_free_group);
+/*
+ * Return the fault parameter of @dev with a reference held, or NULL if the
+ * device has no IOMMU data or no fault parameter. The reference is taken
+ * under dev->iommu->lock; release it with iopf_put_dev_fault_param().
+ */
+static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
+{
+ struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *fault_param;
+
+ if (!param)
+ return NULL;
+
+ mutex_lock(&param->lock);
+ fault_param = param->fault_param;
+ if (fault_param)
+ fault_param->users++;
+ mutex_unlock(&param->lock);
+
+ return fault_param;
+}
+
+/* Release a held reference; dropping the last one frees @fault_param. */
+static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
+{
+ struct device *dev = fault_param->dev;
+ struct dev_iommu *param = dev->iommu;
+
+ mutex_lock(&param->lock);
+ if (WARN_ON(fault_param->users <= 0 ||
+ fault_param != param->fault_param)) {
+ mutex_unlock(&param->lock);
+ return;
+ }
+
+ if (--fault_param->users == 0) {
+ param->fault_param = NULL;
+ kfree(fault_param);
+ put_device(dev);
+ }
+ mutex_unlock(&param->lock);
+}
+
/**
* iommu_handle_iopf - IO Page Fault handler
* @fault: fault event
@@ -72,23 +115,14 @@ static int iommu_handle_iopf(struct iommu_fault *fault, struct device *dev)
struct iopf_group *group;
struct iopf_fault *iopf, *next;
struct iommu_domain *domain = NULL;
- struct iommu_fault_param *iopf_param;
- struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *iopf_param = dev->iommu->fault_param;
- lockdep_assert_held(&param->lock);
+ lockdep_assert_held(&iopf_param->lock);
if (fault->type != IOMMU_FAULT_PAGE_REQ)
/* Not a recoverable page fault */
return -EOPNOTSUPP;
- /*
- * As long as we're holding param->lock, the queue can't be unlinked
- * from the device and therefore cannot disappear.
- */
- iopf_param = param->fault_param;
- if (!iopf_param)
- return -ENODEV;
-
if (!(fault->prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
iopf = kzalloc(sizeof(*iopf), GFP_KERNEL);
if (!iopf)
@@ -173,18 +207,15 @@ static int iommu_handle_iopf(struct iommu_fault *fault, struct device *dev)
*/
int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
{
- struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *fault_param;
struct iopf_fault *evt_pending = NULL;
- struct iommu_fault_param *fparam;
int ret = 0;
- if (!param || !evt)
+ fault_param = evt ? iopf_get_dev_fault_param(dev) : NULL;
+ if (!fault_param)
return -EINVAL;
- /* we only report device fault if there is a handler registered */
- mutex_lock(&param->lock);
- fparam = param->fault_param;
-
+ mutex_lock(&fault_param->lock);
if (evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
(evt->fault.prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
evt_pending = kmemdup(evt, sizeof(struct iopf_fault),
@@ -193,20 +224,18 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
ret = -ENOMEM;
goto done_unlock;
}
- mutex_lock(&fparam->lock);
- list_add_tail(&evt_pending->list, &fparam->faults);
- mutex_unlock(&fparam->lock);
+ list_add_tail(&evt_pending->list, &fault_param->faults);
}
ret = iommu_handle_iopf(&evt->fault, dev);
if (ret && evt_pending) {
- mutex_lock(&fparam->lock);
list_del(&evt_pending->list);
- mutex_unlock(&fparam->lock);
kfree(evt_pending);
}
done_unlock:
- mutex_unlock(&param->lock);
+ mutex_unlock(&fault_param->lock);
+ iopf_put_dev_fault_param(fault_param);
+
return ret;
}
EXPORT_SYMBOL_GPL(iommu_report_device_fault);
@@ -218,19 +247,20 @@ int iommu_page_response(struct device *dev,
int ret = -EINVAL;
struct iopf_fault *evt;
struct iommu_fault_page_request *prm;
- struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *fault_param;
const struct iommu_ops *ops = dev_iommu_ops(dev);
bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
if (!ops->page_response)
return -ENODEV;
- if (!param || !param->fault_param)
- return -EINVAL;
+ fault_param = iopf_get_dev_fault_param(dev);
+ if (!fault_param)
+ return -ENODEV;
/* Only send response if there is a fault report pending */
- mutex_lock(&param->fault_param->lock);
- if (list_empty(&param->fault_param->faults)) {
+ mutex_lock(&fault_param->lock);
+ if (list_empty(&fault_param->faults)) {
dev_warn_ratelimited(dev, "no pending PRQ, drop response\n");
goto done_unlock;
}
@@ -238,7 +268,7 @@ int iommu_page_response(struct device *dev,
* Check if we have a matching page request pending to respond,
* otherwise return -EINVAL
*/
- list_for_each_entry(evt, &param->fault_param->faults, list) {
+ list_for_each_entry(evt, &fault_param->faults, list) {
prm = &evt->fault.prm;
if (prm->grpid != msg->grpid)
continue;
@@ -266,7 +296,9 @@ int iommu_page_response(struct device *dev,
}
done_unlock:
- mutex_unlock(&param->fault_param->lock);
+ mutex_unlock(&fault_param->lock);
+ iopf_put_dev_fault_param(fault_param);
+
return ret;
}
EXPORT_SYMBOL_GPL(iommu_page_response);
@@ -285,22 +317,15 @@ EXPORT_SYMBOL_GPL(iommu_page_response);
*/
int iopf_queue_flush_dev(struct device *dev)
{
- int ret = 0;
- struct iommu_fault_param *iopf_param;
- struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *iopf_param = iopf_get_dev_fault_param(dev);
- if (!param)
+ if (!iopf_param)
return -ENODEV;
- mutex_lock(&param->lock);
- iopf_param = param->fault_param;
- if (iopf_param)
- flush_workqueue(iopf_param->queue->wq);
- else
- ret = -ENODEV;
- mutex_unlock(&param->lock);
+ flush_workqueue(iopf_param->queue->wq);
+ iopf_put_dev_fault_param(iopf_param);
- return ret;
+ return 0;
}
EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
@@ -349,11 +374,13 @@ int iopf_queue_discard_partial(struct iopf_queue *queue)
mutex_lock(&queue->lock);
list_for_each_entry(iopf_param, &queue->devices, queue_list) {
+ mutex_lock(&iopf_param->lock); /* serializes with iommu_handle_iopf() */
list_for_each_entry_safe(iopf, next, &iopf_param->partial,
list) {
list_del(&iopf->list);
kfree(iopf);
}
+ mutex_unlock(&iopf_param->lock);
}
mutex_unlock(&queue->lock);
return 0;
@@ -392,6 +419,7 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
INIT_LIST_HEAD(&fault_param->faults);
INIT_LIST_HEAD(&fault_param->partial);
fault_param->dev = dev;
+ fault_param->users = 1; /* dropped by iopf_queue_remove_device() */
list_add(&fault_param->queue_list, &queue->devices);
fault_param->queue = queue;
@@ -444,9 +472,19 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
- list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
- kfree(iopf);
+ /*
+ * The partial list is now protected by fault_param->lock, and the
+ * data may outlive this call, so unlink entries before freeing them.
+ */
+ mutex_lock(&fault_param->lock);
+ list_for_each_entry_safe(iopf, next, &fault_param->partial, list) {
+ list_del(&iopf->list);
+ kfree(iopf);
+ }
+ mutex_unlock(&fault_param->lock);
- param->fault_param = NULL;
- kfree(fault_param);
- put_device(dev);
+ /*
+ * Drop the reference set up by iopf_queue_add_device(). If other
+ * users remain, the last iopf_put_dev_fault_param() frees the data.
+ */
+ if (--fault_param->users == 0) {
+ param->fault_param = NULL;
+ kfree(fault_param);
+ put_device(dev);
+ }
unlock:
mutex_unlock(&param->lock);
mutex_unlock(&queue->lock);