@@ -53,7 +53,7 @@ static struct iommu_domain *get_domain_for_iopf(struct device *dev,
/**
* iommu_handle_iopf - IO Page Fault handler
* @fault: fault event
- * @dev: struct device.
+ * @iopf_param: the fault parameter of the device.
*
* Add a fault to the device workqueue, to be handled by mm.
*
@@ -90,29 +90,21 @@ static struct iommu_domain *get_domain_for_iopf(struct device *dev,
*
* Return: 0 on success and <0 on error.
*/
-static int iommu_handle_iopf(struct iommu_fault *fault, struct device *dev)
+static int iommu_handle_iopf(struct iommu_fault *fault,
+ struct iommu_fault_param *iopf_param)
{
int ret;
struct iopf_group *group;
struct iommu_domain *domain;
struct iopf_fault *iopf, *next;
- struct iommu_fault_param *iopf_param;
- struct dev_iommu *param = dev->iommu;
+ struct device *dev = iopf_param->dev;

- lockdep_assert_held(&param->lock);
+ lockdep_assert_held(&iopf_param->lock);

if (fault->type != IOMMU_FAULT_PAGE_REQ)
/* Not a recoverable page fault */
return -EOPNOTSUPP;

- /*
- * As long as we're holding param->lock, the queue can't be unlinked
- * from the device and therefore cannot disappear.
- */
- iopf_param = param->fault_param;
- if (!iopf_param)
- return -ENODEV;
-
if (!(fault->prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
iopf = kzalloc(sizeof(*iopf), GFP_KERNEL);
if (!iopf)
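For orientation, the hunk above inverts the old lookup: instead of deriving the fault parameter from the device (and returning -ENODEV when it is already gone), iommu_handle_iopf() now receives the iommu_fault_param directly and recovers the device from it, pushing lifetime and locking to the caller. A minimal sketch of that caller contract, modeled on the iommu_report_device_fault() call site below (the helper name is hypothetical, error paths abbreviated):

	/* Hypothetical caller, for illustration only. */
	static int report_one_fault(struct device *dev, struct iopf_fault *evt)
	{
		struct dev_iommu *param = dev->iommu;
		struct iommu_fault_param *iopf_param;
		int ret;

		mutex_lock(&param->lock);	/* pins param->fault_param */
		iopf_param = param->fault_param;
		if (!iopf_param) {
			mutex_unlock(&param->lock);
			return -EINVAL;
		}

		mutex_lock(&iopf_param->lock);	/* what the lockdep assertion checks */
		ret = iommu_handle_iopf(&evt->fault, iopf_param);
		mutex_unlock(&iopf_param->lock);
		mutex_unlock(&param->lock);

		return ret;
	}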
@@ -186,18 +178,19 @@ static int iommu_handle_iopf(struct iommu_fault *fault, struct device *dev)
*/
int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
{
{
- struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *fault_param;
struct iopf_fault *evt_pending = NULL;
- struct iommu_fault_param *fparam;
+ struct dev_iommu *param = dev->iommu;
int ret = 0;

- if (!param || !evt)
- return -EINVAL;
-
- /* we only report device fault if there is a handler registered */
mutex_lock(&param->lock);
- fparam = param->fault_param;
+ fault_param = param->fault_param;
+ if (!fault_param) {
+ mutex_unlock(&param->lock);
+ return -EINVAL;
+ }
+ mutex_lock(&fault_param->lock);
if (evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
(evt->fault.prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
evt_pending = kmemdup(evt, sizeof(struct iopf_fault),
@@ -206,20 +199,18 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
ret = -ENOMEM;
goto done_unlock;
}
- mutex_lock(&fparam->lock);
- list_add_tail(&evt_pending->list, &fparam->faults);
- mutex_unlock(&fparam->lock);
+ list_add_tail(&evt_pending->list, &fault_param->faults);
}

- ret = iommu_handle_iopf(&evt->fault, dev);
+ ret = iommu_handle_iopf(&evt->fault, fault_param);
if (ret && evt_pending) {
- mutex_lock(&fparam->lock);
list_del(&evt_pending->list);
- mutex_unlock(&fparam->lock);
kfree(evt_pending);
}
done_unlock:
+ mutex_unlock(&fault_param->lock);
mutex_unlock(&param->lock);
+
return ret;
}
EXPORT_SYMBOL_GPL(iommu_report_device_fault);
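Besides pinning fault_param under param->lock, the rework collapses what were three separate fparam->lock critical sections (queue the pending entry, handle the fault, undo on failure) into a single one, so a failed report can drop its pending entry without re-taking the lock. Reduced to its locking skeleton (a sketch; the kmemdup()/kfree() and the outer param->lock nesting are elided):

	mutex_lock(&fault_param->lock);
	list_add_tail(&evt_pending->list, &fault_param->faults);
	ret = iommu_handle_iopf(&evt->fault, fault_param);
	if (ret)
		list_del(&evt_pending->list);	/* undo, still under the same lock */
	mutex_unlock(&fault_param->lock);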
@@ -232,18 +223,23 @@ int iommu_page_response(struct device *dev,
struct iopf_fault *evt;
struct iommu_fault_page_request *prm;
struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *fault_param;
const struct iommu_ops *ops = dev_iommu_ops(dev);
bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;

if (!ops->page_response)
return -ENODEV;

- if (!param || !param->fault_param)
+ mutex_lock(&param->lock);
+ fault_param = param->fault_param;
+ if (!fault_param) {
+ mutex_unlock(&param->lock);
return -EINVAL;
+ }

/* Only send response if there is a fault report pending */
- mutex_lock(&param->fault_param->lock);
- if (list_empty(&param->fault_param->faults)) {
+ mutex_lock(&fault_param->lock);
+ if (list_empty(&fault_param->faults)) {
dev_warn_ratelimited(dev, "no pending PRQ, drop response\n");
goto done_unlock;
}
@@ -251,7 +247,7 @@ int iommu_page_response(struct device *dev,
* Check if we have a matching page request pending to respond,
* otherwise return -EINVAL
*/
- list_for_each_entry(evt, &param->fault_param->faults, list) {
+ list_for_each_entry(evt, &fault_param->faults, list) {
prm = &evt->fault.prm;
if (prm->grpid != msg->grpid)
continue;
@@ -279,7 +275,8 @@ int iommu_page_response(struct device *dev,
}
done_unlock:
- mutex_unlock(&param->fault_param->lock);
+ mutex_unlock(&fault_param->lock);
+ mutex_unlock(&param->lock);

return ret;
}
EXPORT_SYMBOL_GPL(iommu_page_response);
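iommu_page_response() now follows the same discipline: param->lock pins fault_param, and fault_param->lock is held while the pending list is walked. For illustration, a hypothetical driver-side response, assuming only the fields this file already matches on (flags, pasid, grpid, code) and the IOMMU_PAGE_RESP_* constants:

	/* Hypothetical call site; pasid and grpid come from the fault event. */
	struct iommu_page_response resp = {
		.flags = IOMMU_PAGE_RESP_PASID_VALID,
		.pasid = pasid,
		.grpid = grpid,
		.code  = IOMMU_PAGE_RESP_SUCCESS,	/* let the device retry */
	};

	ret = iommu_page_response(dev, &resp);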
@@ -362,11 +359,13 @@ int iopf_queue_discard_partial(struct iopf_queue *queue)
mutex_lock(&queue->lock);
list_for_each_entry(iopf_param, &queue->devices, queue_list) {
+ mutex_lock(&iopf_param->lock);
list_for_each_entry_safe(iopf, next, &iopf_param->partial,
list) {
list_del(&iopf->list);
kfree(iopf);
}
+ mutex_unlock(&iopf_param->lock);
}
mutex_unlock(&queue->lock);

return 0;
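After this hunk, every path that touches iopf_param->partial holds iopf_param->lock: fault intake postpones non-last faults with the lock held by the caller (per the lockdep assertion above), and the discard path here takes it per device while walking the queue. The shared pattern, sketched:

	mutex_lock(&iopf_param->lock);
	list_add(&iopf->list, &iopf_param->partial);	/* or list_del() + kfree() on discard */
	mutex_unlock(&iopf_param->lock);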