Message ID | 20240516081202.27023-8-alucerop@amd.com |
---|---|
State | New, archived |
Headers | show |
Series | RFC: add Type2 device support | expand |
From: Alejandro Lucero <alucerop@amd.com> Prevent concurrent access to endpoint port topology. Based on: https://lore.kernel.org/linux-cxl/168592149709.1948938.8663425987110396027.stgit@dwillia2-xfh.jf.intel.com/T/#m18497367d2ae38f88e94c06369eaa83fa23e92b2 Note: I realize original patch is explaining what the code does while my explanation is ... really poor. I will use original changelog in future versions. Signed-off-by: Alejandro Lucero <alucerop@amd.com> Co-developed-by: Dan Williams <dan.j.williams@intel.com> On 5/16/24 09:11, alucerop@amd.com wrote: > From: Alejandro Lucero <alucerop@amd.com> > > Prevent concurrent access to endpoint port topology. > > Signed-off-by: Alejandro Lucero <alucerop@amd.com> > Signed-off-by: Dan Williams <dan.j.williams@intel.com> > --- > drivers/cxl/core/memdev.c | 41 +++++++++++++++++++++++++++++ > include/linux/cxlmem.h | 4 +++ > tools/testing/cxl/type2/pci_type2.c | 9 +++++++ > 3 files changed, 54 insertions(+) > > diff --git a/drivers/cxl/core/memdev.c b/drivers/cxl/core/memdev.c > index 27063cd4ea73..16e356ef5b6d 100644 > --- a/drivers/cxl/core/memdev.c > +++ b/drivers/cxl/core/memdev.c > @@ -1124,6 +1124,47 @@ struct cxl_memdev *devm_cxl_add_memdev(struct device *host, > } > EXPORT_SYMBOL_NS_GPL(devm_cxl_add_memdev, CXL); > > +/* > + * Try to get a locked reference on a memdev's CXL port topology > + * connection. Be careful to observe when cxl_mem_probe() has deposited > + * a probe deferral awaiting the arrival of the CXL root driver > +*/ > +struct cxl_port *cxl_acquire_endpoint(struct cxl_memdev *cxlmd) > +{ > + struct cxl_port *endpoint; > + int rc = -ENXIO; > + > + device_lock(&cxlmd->dev); > + endpoint = cxlmd->endpoint; > + if (!endpoint) > + goto err; > + > + if (IS_ERR(endpoint)) { > + rc = PTR_ERR(endpoint); > + goto err; > + } > + > + device_lock(&endpoint->dev); > + if (!endpoint->dev.driver) > + goto err_endpoint; > + > + return endpoint; > + > +err_endpoint: > + device_unlock(&endpoint->dev); > +err: > + device_unlock(&cxlmd->dev); > + return ERR_PTR(rc); > +} > +EXPORT_SYMBOL_NS(cxl_acquire_endpoint, CXL); > + > +void cxl_release_endpoint(struct cxl_memdev *cxlmd, struct cxl_port *endpoint) > +{ > + device_unlock(&endpoint->dev); > + device_unlock(&cxlmd->dev); > +} > +EXPORT_SYMBOL_NS(cxl_release_endpoint, CXL); > + > static void sanitize_teardown_notifier(void *data) > { > struct cxl_memdev_state *mds = data; > diff --git a/include/linux/cxlmem.h b/include/linux/cxlmem.h > index e8d12b543db1..11fe8367b046 100644 > --- a/include/linux/cxlmem.h > +++ b/include/linux/cxlmem.h > @@ -88,6 +88,10 @@ static inline bool is_cxl_endpoint(struct cxl_port *port) > > struct cxl_memdev *devm_cxl_add_memdev(struct device *host, > struct cxl_dev_state *cxlds); > + > +struct cxl_port *cxl_acquire_endpoint(struct cxl_memdev *cxlmd); > +void cxl_release_endpoint(struct cxl_memdev *cxlmd, struct cxl_port *endpoint); > + > int devm_cxl_sanitize_setup_notifier(struct device *host, > struct cxl_memdev *cxlmd); > struct cxl_memdev_state; > diff --git a/tools/testing/cxl/type2/pci_type2.c b/tools/testing/cxl/type2/pci_type2.c > index f157139b712f..948cc95c5780 100644 > --- a/tools/testing/cxl/type2/pci_type2.c > +++ b/tools/testing/cxl/type2/pci_type2.c > @@ -6,6 +6,7 @@ > > struct cxl_dev_state *cxlds; > struct cxl_memdev *cxlmd; > +struct cxl_port *endpoint; > > #define CXL_TYPE2_MEM_SIZE (1024*1024*256) > > @@ -72,6 +73,14 @@ static int type2_pci_probe(struct pci_dev *pci_dev, > if (IS_ERR(cxlmd)) > return PTR_ERR(cxlmd); > > + endpoint = cxl_acquire_endpoint(cxlmd); > + if (IS_ERR(endpoint)) { > + dev_dbg(&pci_dev->dev, "cxl_acquire_endpoint failed\n"); > + return PTR_ERR(endpoint); > + } > + > + cxl_release_endpoint(cxlmd, endpoint); > + > return 0; > } >
diff --git a/drivers/cxl/core/memdev.c b/drivers/cxl/core/memdev.c index 27063cd4ea73..16e356ef5b6d 100644 --- a/drivers/cxl/core/memdev.c +++ b/drivers/cxl/core/memdev.c @@ -1124,6 +1124,47 @@ struct cxl_memdev *devm_cxl_add_memdev(struct device *host, } EXPORT_SYMBOL_NS_GPL(devm_cxl_add_memdev, CXL); +/* + * Try to get a locked reference on a memdev's CXL port topology + * connection. Be careful to observe when cxl_mem_probe() has deposited + * a probe deferral awaiting the arrival of the CXL root driver +*/ +struct cxl_port *cxl_acquire_endpoint(struct cxl_memdev *cxlmd) +{ + struct cxl_port *endpoint; + int rc = -ENXIO; + + device_lock(&cxlmd->dev); + endpoint = cxlmd->endpoint; + if (!endpoint) + goto err; + + if (IS_ERR(endpoint)) { + rc = PTR_ERR(endpoint); + goto err; + } + + device_lock(&endpoint->dev); + if (!endpoint->dev.driver) + goto err_endpoint; + + return endpoint; + +err_endpoint: + device_unlock(&endpoint->dev); +err: + device_unlock(&cxlmd->dev); + return ERR_PTR(rc); +} +EXPORT_SYMBOL_NS(cxl_acquire_endpoint, CXL); + +void cxl_release_endpoint(struct cxl_memdev *cxlmd, struct cxl_port *endpoint) +{ + device_unlock(&endpoint->dev); + device_unlock(&cxlmd->dev); +} +EXPORT_SYMBOL_NS(cxl_release_endpoint, CXL); + static void sanitize_teardown_notifier(void *data) { struct cxl_memdev_state *mds = data; diff --git a/include/linux/cxlmem.h b/include/linux/cxlmem.h index e8d12b543db1..11fe8367b046 100644 --- a/include/linux/cxlmem.h +++ b/include/linux/cxlmem.h @@ -88,6 +88,10 @@ static inline bool is_cxl_endpoint(struct cxl_port *port) struct cxl_memdev *devm_cxl_add_memdev(struct device *host, struct cxl_dev_state *cxlds); + +struct cxl_port *cxl_acquire_endpoint(struct cxl_memdev *cxlmd); +void cxl_release_endpoint(struct cxl_memdev *cxlmd, struct cxl_port *endpoint); + int devm_cxl_sanitize_setup_notifier(struct device *host, struct cxl_memdev *cxlmd); struct cxl_memdev_state; diff --git a/tools/testing/cxl/type2/pci_type2.c b/tools/testing/cxl/type2/pci_type2.c index f157139b712f..948cc95c5780 100644 --- a/tools/testing/cxl/type2/pci_type2.c +++ b/tools/testing/cxl/type2/pci_type2.c @@ -6,6 +6,7 @@ struct cxl_dev_state *cxlds; struct cxl_memdev *cxlmd; +struct cxl_port *endpoint; #define CXL_TYPE2_MEM_SIZE (1024*1024*256) @@ -72,6 +73,14 @@ static int type2_pci_probe(struct pci_dev *pci_dev, if (IS_ERR(cxlmd)) return PTR_ERR(cxlmd); + endpoint = cxl_acquire_endpoint(cxlmd); + if (IS_ERR(endpoint)) { + dev_dbg(&pci_dev->dev, "cxl_acquire_endpoint failed\n"); + return PTR_ERR(endpoint); + } + + cxl_release_endpoint(cxlmd, endpoint); + return 0; }