diff mbox series

[11/11] samples/devsec: Add sample IDE establishment

Message ID 173343745958.1074769.12896997365766327404.stgit@dwillia2-xfh.jf.intel.com (mailing list archive)
State New
Headers show
Series PCI/TSM: Core infrastructure for PCI device security (TDISP) | expand

Commit Message

Dan Williams Dec. 5, 2024, 10:24 p.m. UTC
Exercise common setup and teardown flows for a sample platform TSM
driver that implements the TSM 'connect' and 'disconnect' flows.

This is both a template for platform specific implementations and a test
case for the shared infrastructure.

Cc: Bjorn Helgaas <bhelgaas@google.com>
Cc: Lukas Wunner <lukas@wunner.de>
Cc: Samuel Ortiz <sameo@rivosinc.com>
Cc: Alexey Kardashevskiy <aik@amd.com>
Cc: Xu Yilun <yilun.xu@linux.intel.com>
Signed-off-by: Dan Williams <dan.j.williams@intel.com>
---
 samples/devsec/tsm.c |   85 ++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 82 insertions(+), 3 deletions(-)
diff mbox series

Patch

diff --git a/samples/devsec/tsm.c b/samples/devsec/tsm.c
index d446ab8879d8..a8894d08f323 100644
--- a/samples/devsec/tsm.c
+++ b/samples/devsec/tsm.c
@@ -4,11 +4,14 @@ 
 #define dev_fmt(fmt) "devsec: " fmt
 #include <linux/platform_device.h>
 #include <linux/pci-tsm.h>
+#include <linux/pci-ide.h>
 #include <linux/module.h>
 #include <linux/pci.h>
 #include <linux/tsm.h>
 #include "devsec.h"
 
+#define DEVSEC_NR_IDE_STREAMS 4
+
 struct devsec_dsm {
 	struct pci_dsm pci;
 };
@@ -42,13 +45,60 @@  static void devsec_tsm_pci_remove(struct pci_dsm *dsm)
 	kfree(devsec_dsm);
 }
 
+/* protected by tsm_ops lock */
+static DECLARE_BITMAP(devsec_stream_ids, DEVSEC_NR_IDE_STREAMS);
+static struct devsec_stream_info {
+	struct pci_dev *pdev;
+	struct pci_ide ide;
+} devsec_streams[DEVSEC_NR_IDE_STREAMS];
+
 static int devsec_tsm_connect(struct pci_dev *pdev)
 {
-	return -ENXIO;
+	struct pci_ide *ide;
+	int rc, stream_id;
+
+	stream_id =
+		find_first_zero_bit(devsec_stream_ids, DEVSEC_NR_IDE_STREAMS);
+	if (stream_id == DEVSEC_NR_IDE_STREAMS)
+		return -EBUSY;
+	set_bit(stream_id, devsec_stream_ids);
+	ide = &devsec_streams[stream_id].ide;
+	pci_ide_stream_probe(pdev, ide);
+
+	ide->stream_id = stream_id;
+	rc = pci_ide_stream_setup(pdev, ide, PCI_IDE_SETUP_ROOT_PORT);
+	if (rc)
+		return rc;
+	rc = tsm_register_ide_stream(pdev, ide);
+	if (rc)
+		goto err;
+
+	devsec_streams[stream_id].pdev = pdev;
+	pci_ide_enable_stream(pdev, ide);
+	return 0;
+err:
+	pci_ide_stream_teardown(pdev, ide, PCI_IDE_SETUP_ROOT_PORT);
+	return rc;
 }
 
 static void devsec_tsm_disconnect(struct pci_dev *pdev)
 {
+	struct pci_ide *ide;
+	int i;
+
+	for_each_set_bit(i, devsec_stream_ids, DEVSEC_NR_IDE_STREAMS)
+		if (devsec_streams[i].pdev == pdev)
+			break;
+
+	if (i >= DEVSEC_NR_IDE_STREAMS)
+		return;
+
+	ide = &devsec_streams[i].ide;
+	pci_ide_disable_stream(pdev, ide);
+	tsm_unregister_ide_stream(ide);
+	pci_ide_stream_teardown(pdev, ide, PCI_IDE_SETUP_ROOT_PORT);
+	devsec_streams[i].pdev = NULL;
+	clear_bit(i, devsec_stream_ids);
 }
 
 static const struct pci_tsm_ops devsec_pci_ops = {
@@ -63,16 +113,44 @@  static void devsec_tsm_remove(void *tsm)
 	tsm_unregister(tsm);
 }
 
+static void set_nr_ide_streams(int nr)
+{
+	struct pci_dev *pdev = NULL;
+
+	for_each_pci_dev(pdev) {
+		struct pci_host_bridge *hb;
+
+		if (pdev->sysdata != devsec_sysdata)
+			continue;
+		hb = pci_find_host_bridge(pdev->bus);
+		if (hb->nr_ide_streams >= 0)
+			continue;
+		pci_set_nr_ide_streams(hb, nr);
+	}
+}
+
+static void devsec_tsm_ide_teardown(void *data)
+{
+	set_nr_ide_streams(-1);
+}
+
 static int devsec_tsm_probe(struct platform_device *pdev)
 {
 	struct tsm_subsys *tsm;
+	int rc;
 
 	tsm = tsm_register(&pdev->dev, NULL, &devsec_pci_ops);
 	if (IS_ERR(tsm))
 		return PTR_ERR(tsm);
 
-	return devm_add_action_or_reset(&pdev->dev, devsec_tsm_remove,
-					tsm);
+	rc = devm_add_action_or_reset(&pdev->dev, devsec_tsm_remove, tsm);
+	if (rc)
+		return rc;
+
+	set_nr_ide_streams(DEVSEC_NR_IDE_STREAMS);
+
+	return devm_add_action_or_reset(&pdev->dev, devsec_tsm_ide_teardown,
+					NULL);
 }
 
 static struct platform_driver devsec_tsm_driver = {
@@ -109,5 +187,6 @@  static void __exit devsec_tsm_exit(void)
 }
 module_exit(devsec_tsm_exit);
 
+MODULE_IMPORT_NS(PCI_IDE);
 MODULE_LICENSE("GPL");
 MODULE_DESCRIPTION("Device Security Sample Infrastructure: Platform TSM Driver");