diff mbox

KVM test: Add PCI pass through test

Message ID 1255522080-10570-1-git-send-email-lmr@redhat.com (mailing list archive)
State New, archived
Headers show

Commit Message

Lucas Meneghel Rodrigues Oct. 14, 2009, 12:08 p.m. UTC
None
diff mbox

Patch

diff --git a/client/tests/kvm/kvm_tests.cfg.sample b/client/tests/kvm/kvm_tests.cfg.sample
index cc3228a..1dad188 100644
--- a/client/tests/kvm/kvm_tests.cfg.sample
+++ b/client/tests/kvm/kvm_tests.cfg.sample
@@ -786,13 +786,22 @@  variants:
         only default
         image_format = raw
 
-
 variants:
     - @smallpages:
     - hugepages:
         pre_command = "/usr/bin/python scripts/hugepage.py /mnt/kvm_hugepage"
         extra_params += " -mem-path /mnt/kvm_hugepage"
 
+variants:
+    - @no_passthrough:
+        pass_through = no
+    - nic_passthrough:
+        pass_through = pf
+        passthrough_devs = eth1
+    - vfs_passthrough:
+        pass_through = vf
+        max_vfs = 7
+        vfs_count = 7
 
 variants:
     - @basic:
diff --git a/client/tests/kvm/kvm_utils.py b/client/tests/kvm/kvm_utils.py
index 53b664a..0e3398c 100644
--- a/client/tests/kvm/kvm_utils.py
+++ b/client/tests/kvm/kvm_utils.py
@@ -788,3 +788,281 @@  def md5sum_file(filename, size=None):
         size -= len(data)
     f.close()
     return o.hexdigest()
+
+
+def get_full_id(pci_id):
+    """
+    Get full PCI ID of pci_id.
+    """
+    cmd = "lspci -D | awk '/%s/ {print $1}'" % pci_id
+    status, full_id = commands.getstatusoutput(cmd)
+    if status != 0:
+        return None
+    return full_id
+
+
+def get_vendor_id(pci_id):
+    """
+    Check out the device vendor ID according to PCI ID.
+    """
+    cmd = "lspci -n | awk '/%s/ {print $3}'" % pci_id
+    return re.sub(":", " ", commands.getoutput(cmd))
+
+
+def release_dev(pci_id, pci_dict):
+    """
+    Release a single PCI device.
+
+    @param pci_id: PCI ID of a given PCI device
+    @param pci_dict: Dictionary with information about PCI devices
+    """
+    base_dir = "/sys/bus/pci"
+    full_id = get_full_id(pci_id)
+    vendor_id = get_vendor_id(pci_id)
+    drv_path = os.path.join(base_dir, "devices/%s/driver" % full_id)
+    if 'pci-stub' in os.readlink(drv_path):
+        cmd = "echo '%s' > %s/new_id" % (vendor_id, drv_path)
+        if os.system(cmd):
+            return False
+
+        stub_path = os.path.join(base_dir, "drivers/pci-stub")
+        cmd = "echo '%s' > %s/unbind" % (full_id, stub_path)
+        if os.system(cmd):
+            return False
+
+        prev_driver = pci_dict[pci_id]
+        cmd = "echo '%s' > %s/bind" % (full_id, prev_driver)
+        if os.system(cmd):
+            return False
+    return True
+
+
+def release_pci_devs(pci_dict):
+    """
+    Release all PCI devices assigned to host.
+
+    @param pci_dict: Dictionary with information about PCI devices
+    """
+    for pci_id in pci_dict:
+        if not release_dev(pci_id, pci_dict):
+            logging.error("Failed to release device [%s] to host" % pci_id)
+        else:
+            logging.info("Release device [%s] successfully" % pci_id)
+
+
+class PassThrough(object):
+    """
+    Request passthroughable devices on host. It will check whether to request
+    PF (physical NIC cards) or VF (Virtual Functions).
+    """
+    def __init__(self, type="nic_vf", max_vfs=None, names=None):
+        """
+        Initialize parameter 'type' which could be:
+        nic_vf: Virtual Functions
+        nic_pf: Physical NIC card
+        mixed:  Both includes VFs and PFs
+
+        If pass through Physical NIC cards, we need to specify which devices
+        to be assigned, e.g. 'eth1 eth2'.
+
+        If pass through Virtual Functions, we need to specify how many vfs
+        are going to be assigned, e.g. passthrough_count = 8 and max_vfs in
+        config file.
+
+        @param type: Pass through device's type
+        @param max_vfs: parameter of module 'igb'
+        @param names: Physical NIC cards' names, e.g.'eth1 eth2 ...'
+        """
+        self.type = type
+        if max_vfs:
+            self.max_vfs = int(max_vfs)
+        if names:
+            self.name_list = names.split()
+
+
+    def sr_iov_setup(self):
+        """
+        Setup SR-IOV environment, check if module 'igb' is loaded with
+        parameter 'max_vfs'.
+        """
+        re_probe = False
+        # Check whether the module 'igb' is loaded
+        s, o = commands.getstatusoutput('lsmod | grep igb')
+        if s:
+            re_probe = True
+        elif not self.chk_vfs_count():
+            os.system("modprobe -r igb")
+            re_probe = True
+
+        # Re-probe module 'igb'
+        if re_probe:
+            cmd = "modprobe igb max_vfs=%d" % self.max_vfs
+            s, o = commands.getstatusoutput(cmd)
+            if s:
+                return False
+            if not self.chk_vfs_count():
+                return False
+        return True
+
+
+    def get_vfs_count(self):
+        """
+        Get Vfs count number according to 'max_vfs'
+        """
+        cmd = "lspci | grep 'Virtual Function' | wc -l"
+        return int(commands.getoutput(cmd))
+
+
+    def chk_vfs_count(self):
+        """
+        Check VFs count number according to parameter 'max_vfs'.
+        """
+        if self.get_vfs_count() != (self.max_vfs * 2):
+            return False
+        return True
+
+
+    def _get_pci_id(self, name, search_str):
+        """
+        Get the device's PCI ID according to name.
+        """
+        cmd = "ethtool -i %s | awk '/bus-info/ {print $2}'" % name
+        s, pci_id = commands.getstatusoutput(cmd)
+        if not (s or "Cannot get driver information" in pci_id):
+            return pci_id[5:]
+
+        cmd = "lspci | awk '/%s/ {print $1}'" % search_str
+        pci_ids = [id for id in commands.getoutput(cmd).splitlines()]
+        nic_id = int(re.search('[0-9]+', name).group(0))
+        if (len(pci_ids) - 1) < nic_id:
+            return None
+        return pci_ids[nic_id]
+
+
+    def _write_to_file(self, content, file_path):
+        """
+        Write some content to a file.
+
+        @param content: Content to be written.
+        @param file_path: Path for the file.
+        """
+        success = False
+        try:
+            file = open(file_path, 'w')
+        except IOError:
+            return success
+
+        try:
+            file.write(content)
+            success = True
+        except IOError:
+            pass
+        finally:
+            file.close()
+
+        return success
+
+
+    def is_binded_to_stub(self, full_id):
+        """
+        Verify whether the device with full_id is already binded to pci-stub.
+        """
+        base_dir = "/sys/bus/pci"
+        stub_path = os.path.join(base_dir, "drivers/pci-stub")
+        if os.path.exists(os.path.join(stub_path, full_id)):
+            return True
+        return False
+
+
+    def request_devs(self, count, pt_file, setup=True):
+        """
+        Implement setup process: unbind the PCI device and then bind it
+        to pci-stub driver.
+
+        @return: a list of successfully requested devices' PCI IDs.
+        """
+        base_dir = "/sys/bus/pci"
+        stub_path = os.path.join(base_dir, "drivers/pci-stub")
+
+        self.pci_ids = self.catch_devs(int(count))
+        logging.debug("Catch_devs has found pci_ids: %s" % self.pci_ids)
+
+        if setup:
+            requested_pci_ids = []
+            logging.debug("Going to setup devices: %s" % self.pci_ids)
+            dev_prev_drivers = {}
+
+            # Setup all devices specified for passthrough to guest
+            for pci_id in self.pci_ids:
+                full_id = get_full_id(pci_id)
+                if not full_id:
+                    continue
+                drv_path = os.path.join(base_dir, "devices/%s/driver" % full_id)
+                dev_prev_driver= os.path.realpath(os.path.join(drv_path,
+                                                  os.readlink(drv_path)))
+                dev_prev_drivers[pci_id] = dev_prev_driver
+
+                # Judge whether the device driver has been binded to stub
+                if not self.is_binded_to_stub(full_id):
+                    logging.debug("Setting up device: %s" % full_id)
+                    stub_new_id = os.path.join(stub_path, 'new_id')
+                    unbind_dev = os.path.join(drv_path, 'unbind')
+                    stub_bind = os.path.join(stub_path, 'bind')
+
+                    vendor_id = get_vendor_id(pci_id)
+                    if not self._write_to_file(vendor_id, stub_new_id):
+                        continue
+                    if not self._write_to_file(full_id, unbind_dev):
+                        continue
+                    if not self._write_to_file(full_id, stub_bind):
+                        continue
+                    if not self.is_binded_to_stub(full_id):
+                        logging.error("Request device failed: %s" % pci_id)
+                        continue
+                else:
+                    logging.debug("Device [%s] already binded to stub" % pci_id)
+                requested_pci_ids.append(pci_id)
+            cPickle.dump(dev_prev_drivers, pt_file)
+            self.pci_ids = requested_pci_ids
+        return self.pci_ids
+
+
+    def catch_vf_devs(self):
+        """
+        Catch all VFs PCI IDs and return as a list.
+        """
+        if not self.sr_iov_setup():
+            return []
+
+        cmd = "lspci | awk '/Virtual Function/ {print $1}'"
+        return commands.getoutput(cmd).split()
+
+
+    def catch_pf_devs(self):
+        """
+        Catch all PFs PCI IDs.
+        """
+        pf_ids = []
+        for name in self.name_list:
+            pf_id = self._get_pci_id(name, "Ethernet")
+            if not pf_id:
+                continue
+            pf_ids.append(pf_id)
+        return pf_ids
+
+
+    def catch_devs(self, count):
+        """
+        Check out all devices' PCI IDs according to their name.
+
+        @param count: count number of PCI devices needed for pass through
+        @return: a list of all devices' PCI IDs
+        """
+        if self.type == "nic_vf":
+            vf_ids = self.catch_vf_devs()
+        elif self.type == "nic_pf":
+            vf_ids = self.catch_pf_devs()
+        elif self.type == "mixed":
+            vf_ids = self.catch_vf_devs()
+            vf_ids.extend(self.catch_pf_devs())
+        return vf_ids[0:count]
diff --git a/client/tests/kvm/kvm_vm.py b/client/tests/kvm/kvm_vm.py
index 82f1eb4..a0b01fe 100755
--- a/client/tests/kvm/kvm_vm.py
+++ b/client/tests/kvm/kvm_vm.py
@@ -298,6 +298,30 @@  class VM:
         elif params.get("uuid"):
             qemu_cmd += " -uuid %s" % params.get("uuid")
 
+        # Check whether pci_passthrough is turned on
+        if not params.get("pass_through") == "no":
+            pt_type = params.get("pass_through")
+
+            # Virtual Functions pass through
+            if params.get("pass_through") == "nic_vf":
+                pt = kvm_utils.PassThrough(type=pt_type,
+                                           max_vfs=params.get("max_vfs"))
+            # Physical NIC card pass through
+            elif params.get("pass_through") == "nic_pf":
+                pt = kvm_utils.PassThrough(type=pt_type,
+                                           names=params.get("passthrough_devs"))
+            elif params.get("pass_through") == "mixed":
+                pt = kvm_utils.PassThrough(type=pt_type,
+                                           max_vfs=params.get("max_vfs"),
+                                           names=params.get("passthrough_devs"))
+            pt_count = params.get("passthrough_count")
+            pt_db_file = open("/tmp/passthrough_db", "w")
+            self.pt_pci_ids = pt.request_devs(pt_count, pt_file=pt_db_file,
+                                              setup=False)
+            pt_db_file.close()
+            for pci_id in self.pt_pci_ids:
+                qemu_cmd += " -pcidevice host=%s" % pci_id
+
         return qemu_cmd
 
 
@@ -380,6 +404,37 @@  class VM:
                 self.uuid = f.read().strip()
                 f.close()
 
+            # Request specified PCI devices for assignment
+            if not params.get("pass_through") == "no":
+                pt_type = params.get("pass_through")
+
+                # Virtual Functions pass through
+                if params.get("pass_through") == "nic_vf":
+                    logging.debug("Going to assign VFs to guest")
+                    pt = kvm_utils.PassThrough(type=pt_type,
+                                               max_vfs = params.get("max_vfs"))
+                # Physical NIC card pass through
+                elif params.get("pass_through") == "nic_pf":
+                    logging.debug("Going to assign physical NICs to guest")
+                    pt = kvm_utils.PassThrough(type=pt_type,
+                                           names=params.get("passthrough_devs"))
+                elif params.get("pass_through") == "mixed":
+                    logging.debug("Assigned NICs may contain both VFs and PFs")
+                    pt = kvm_utils.PassThrough(type = pt_type,
+                                               max_vfs = params.get("max_vfs"),
+                                         names = params.get("passthrough_devs"))
+                pt_count = params.get("passthrough_count")
+                pt_db_file = open('/tmp/passthrough_db', 'w')
+                self.pt_pci_ids = pt.request_devs(pt_count, pt_file=pt_db_file,
+                                                  setup=True)
+                pt_db_file.close()
+                if not self.pt_pci_ids:
+                    logging.error("Request specified PCI device(s) failed;"
+                                  " pt_pci_ids are: %s", self.pt_pci_ids)
+                    return False
+                logging.debug("Requested host PCI device(s): %s",
+                              self.pt_pci_ids)
+
             # Make qemu command
             qemu_command = self.make_qemu_command()
 
@@ -504,6 +559,19 @@  class VM:
         return (0, data)
 
 
+    def _release_pt_devs(self):
+        """
+        Release the assigned passthrough PCI devices or VFs to host.
+        """
+        try:
+            db_file = open('/tmp/passthrough_db', 'r')
+            pt_db = cPickle.load(db_file)
+            kvm_utils.release_pci_devs(pt_db)
+            db_file.close()
+        except:
+            return
+
+
     def destroy(self, gracefully=True):
         """
         Destroy the VM.
@@ -520,6 +588,7 @@  class VM:
             # Is it already dead?
             if self.is_dead():
                 logging.debug("VM is already down")
+                self._release_pt_devs()
                 return
 
             logging.debug("Destroying VM with PID %d..." %
@@ -540,6 +609,7 @@  class VM:
                             return
                     finally:
                         session.close()
+                        self._release_pt_devs()
 
             # Try to destroy with a monitor command
             logging.debug("Trying to kill VM with monitor command...")
@@ -549,6 +619,7 @@  class VM:
                 # Wait for the VM to be really dead
                 if kvm_utils.wait_for(self.is_dead, 5, 0.5, 0.5):
                     logging.debug("VM is down")
+                    self._release_pt_devs()
                     return
 
             # If the VM isn't dead yet...
@@ -558,6 +629,7 @@  class VM:
             # Wait for the VM to be really dead
             if kvm_utils.wait_for(self.is_dead, 5, 0.5, 0.5):
                 logging.debug("VM is down")
+                self._release_pt_devs()
                 return
 
             logging.error("Process %s is a zombie!" % self.process.get_pid())