new file mode 100644
@@ -0,0 +1,412 @@
+#!/usr/bin/python
+"""
+Client for file transfer services offered by RSS (Remote Shell Server).
+
+@author: Michael Goldish (mgoldish@redhat.com)
+@copyright: 2008-2010 Red Hat Inc.
+"""
+
+import socket, struct, time, sys, os, glob
+
+# Globals
+BUFSIZE = 65536
+
+# Protocol message constants
+RSS_MAGIC = 0x525353
+RSS_OK = 1
+RSS_ERROR = 2
+RSS_UPLOAD = 3
+RSS_DOWNLOAD = 4
+RSS_SET_PATH = 5
+RSS_CREATE_FILE = 6
+RSS_CREATE_DIR = 7
+RSS_LEAVE_DIR = 8
+RSS_DONE = 9
+
+# See rss.cpp for protocol details.
+
+
+class FileTransferError(Exception):
+ pass
+
+
+class FileTransferConnectError(FileTransferError):
+ pass
+
+
+class FileTransferTimeoutError(FileTransferError):
+ pass
+
+
+class FileTransferProtocolError(FileTransferError):
+ pass
+
+
+class FileTransferSendError(FileTransferError):
+ pass
+
+
+class FileTransferServerError(FileTransferError):
+ pass
+
+
+class FileTransferClient(object):
+ """
+ Connect to a RSS (remote shell server) and transfer files.
+ """
+
+ def __init__(self, address, port, timeout=10):
+ """
+ Connect to a server.
+
+ @param address: The server's address
+ @param port: The server's port
+ @param timeout: Time duration to wait for connection to succeed
+ @raise FileTransferConnectError: Raised if the connection fails
+ @raise FileTransferProtocolError: Raised if an incorrect magic number
+ is received
+ """
+ self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._socket.settimeout(timeout)
+ try:
+ self._socket.connect((address, port))
+ except socket.error:
+ raise FileTransferConnectError("Could not connect to server")
+ if self._receive_msg(timeout) != RSS_MAGIC:
+ raise FileTransferProtocolError("Received wrong magic number")
+
+
+ def __del__(self):
+ self.close()
+
+
+ def close(self):
+ """
+ Close the connection.
+ """
+ self._socket.close()
+
+
+ def _send(self, str):
+ try:
+ self._socket.sendall(str)
+ except socket.error:
+ raise FileTransferSendError("Could not send data to server")
+
+
+ def _receive(self, size, timeout=10):
+ s = ""
+ end_time = time.time() + timeout
+ while size:
+ try:
+ self._socket.settimeout(max(0, end_time - time.time()))
+ data = self._socket.recv(min(BUFSIZE, size))
+ except socket.timeout:
+ raise FileTransferTimeoutError("Timeout expired while "
+ "receiving data from server")
+ except socket.error:
+ raise FileTransferProtocolError("Error receiving data from "
+ "server")
+ if not data:
+ raise FileTransferProtocolError("Connection closed "
+ "unexpectedly")
+ s += data
+ size -= len(data)
+ return s
+
+
+ def _send_packet(self, str):
+ self._send(struct.pack("=Q", len(str)))
+ self._send(str)
+
+
+ def _receive_packet(self, timeout=10):
+ size = struct.unpack("=Q", self._receive(8))[0]
+ return self._receive(size, timeout)
+
+
+ def _send_packet_from_file(self, filename, timeout=30):
+ self._send(struct.pack("=Q", os.path.getsize(filename)))
+ f = open(filename, "rb")
+ try:
+ end_time = time.time() + timeout
+ while time.time() < end_time:
+ data = f.read(BUFSIZE)
+ if not data:
+ break
+ self._send(data)
+ else:
+ raise FileTransferTimeoutError("Timeout expired while sending "
+ "file %s" % filename)
+ finally:
+ f.close()
+
+
+ def _receive_packet_into_file(self, filename, timeout=30):
+ size = struct.unpack("=Q", self._receive(8))[0]
+ f = open(filename, "wb")
+ try:
+ end_time = time.time() + timeout
+ while size:
+ try:
+ self._socket.settimeout(max(0, end_time - time.time()))
+ data = self._socket.recv(min(BUFSIZE, size))
+ except socket.timeout:
+ raise FileTransferTimeoutError("Timeout expired while "
+ "receiving file %s" %
+ filename)
+ except socket.error:
+ raise FileTransferProtocolError("Error receiving file %s" %
+ filename)
+ if not data:
+ raise FileTransferProtocolError("Connection closed "
+ "unexpectedly while "
+ "receiving file %s" %
+ filename)
+ f.write(data)
+ size -= len(data)
+ finally:
+ f.close()
+
+
+ def _send_msg(self, msg, timeout=10):
+ self._send(struct.pack("=i", msg))
+
+
+ def _receive_msg(self, timeout=10):
+ s = self._receive(4, timeout)
+ return struct.unpack("=i", s)[0]
+
+
+ def _handle_transfer_error(self):
+ # Save original exception
+ e = sys.exc_info()
+ try:
+ # See if we can get an error message
+ msg = self._receive_msg()
+ except FileTransferError:
+ # No error message -- re-raise original exception
+ raise e[0], e[1], e[2]
+ if msg == RSS_ERROR:
+ errmsg = self._receive_packet()
+ raise FileTransferServerError("Server said: %s" % errmsg)
+ raise e[0], e[1], e[2]
+
+
+class FileUploadClient(FileTransferClient):
+ """
+ Connect to a RSS (remote shell server) and upload files or directory trees.
+ """
+
+ def __init__(self, address, port, timeout=10):
+ """
+ Connect to a server.
+
+ @param address: The server's address
+ @param port: The server's port
+ @param timeout: Time duration to wait for connection to succeed
+ @raise FileTransferConnectError: Raised if the connection fails
+ @raise FileTransferProtocolError: Raised if an incorrect magic number
+ is received
+ @raise FileTransferSendError: Raised if the RSS_UPLOAD message cannot
+ be sent to the server
+ """
+ super(FileUploadClient, self).__init__(address, port, timeout)
+ self._send_msg(RSS_UPLOAD)
+
+
+ def _upload_file(self, path, end_time):
+ if os.path.isfile(path):
+ self._send_msg(RSS_CREATE_FILE)
+ self._send_packet(os.path.basename(path))
+ self._send_packet_from_file(path, max(0, end_time - time.time()))
+ elif os.path.isdir(path):
+ self._send_msg(RSS_CREATE_DIR)
+ self._send_packet(os.path.basename(path))
+ for filename in os.listdir(path):
+ self._upload_file(os.path.join(path, filename), end_time)
+ self._send_msg(RSS_LEAVE_DIR)
+
+
+ def upload(self, src_pattern, dst_path, timeout=600):
+ """
+ Send files or directory trees to the server.
+ The semantics of src_pattern and dst_path are similar to those of scp.
+ For example, the following are OK:
+ src_pattern='/tmp/foo.txt', dst_path='C:\\'
+ (copies a single file)
+ src_pattern='/usr/', dst_path='C:\\Windows\\'
+ (copies a directory tree recursively)
+ src_pattern='/usr/*', dst_path='C:\\Windows\\'
+ (copies all files and directory trees under /usr/)
+ The following is not OK:
+ src_pattern='/tmp/foo.txt', dst_path='C:\\Windows\\*'
+ (wildcards are only allowed in src_pattern)
+
+ @param src_pattern: A path or wildcard pattern specifying the files or
+ directories to send to the server
+ @param dst_path: A path in the server's filesystem where the files will
+ be saved
+ @param timeout: Time duration in seconds to wait for the transfer to
+ complete
+ @raise FileTransferTimeoutError: Raised if timeout expires
+ @raise FileTransferServerError: Raised if something goes wrong and the
+ server sends an informative error message to the client
+ @note: Other exceptions can be raised.
+ """
+ end_time = time.time() + timeout
+ try:
+ try:
+ self._send_msg(RSS_SET_PATH)
+ self._send_packet(dst_path)
+ for filename in glob.iglob(src_pattern):
+ self._upload_file(os.path.abspath(filename), end_time)
+ self._send_msg(RSS_DONE)
+ except FileTransferTimeoutError:
+ raise
+ except FileTransferError:
+ self._handle_transfer_error()
+ else:
+ # Look for RSS_OK or RSS_ERROR
+ msg = self._receive_msg(max(0, end_time - time.time()))
+ if msg == RSS_OK:
+ return
+ elif msg == RSS_ERROR:
+ errmsg = self._receive_packet()
+ raise FileTransferServerError("Server said: %s" % errmsg)
+ else:
+ # Neither RSS_OK nor RSS_ERROR found
+ raise FileTransferProtocolError
+ except:
+ # In any case, if the transfer failed, close the connection
+ self.close()
+ raise
+
+
+class FileDownloadClient(FileTransferClient):
+ """
+ Connect to a RSS (remote shell server) and download files or directory trees.
+ """
+
+ def __init__(self, address, port, timeout=10):
+ """
+ Connect to a server.
+
+ @param address: The server's address
+ @param port: The server's port
+ @param timeout: Time duration to wait for connection to succeed
+ @raise FileTransferConnectError: Raised if the connection fails
+ @raise FileTransferProtocolError: Raised if an incorrect magic number
+ is received
+ @raise FileTransferSendError: Raised if the RSS_UPLOAD message cannot
+ be sent to the server
+ """
+ super(FileDownloadClient, self).__init__(address, port, timeout)
+ self._send_msg(RSS_DOWNLOAD)
+
+
+ def download(self, src_pattern, dst_path, timeout=600):
+ """
+ Receive files or directory trees from the server.
+ The semantics of src_pattern and dst_path are similar to those of scp.
+ For example, the following are OK:
+ src_pattern='C:\\foo.txt', dst_path='/tmp'
+ (receives a single file)
+ src_pattern='C:\\Windows', dst_path='/tmp'
+ (receives a directory tree recursively)
+ src_pattern='C:\\Windows\\*', dst_path='/tmp'
+ (receives all files and directory trees under C:\\Windows)
+ The following is not OK:
+ src_pattern='C:\\Windows', dst_path='/tmp/*'
+ (wildcards are only allowed in src_pattern)
+
+ @param src_pattern: A path or wildcard pattern specifying the files or
+ directories, in the server's filesystem, that will be sent to
+ the client
+ @param dst_path: A path in the local filesystem where the files will
+ be saved
+ @param timeout: Time duration in seconds to wait for the transfer to
+ complete
+ @raise FileTransferTimeoutError: Raised if timeout expires
+ @raise FileTransferServerError: Raised if something goes wrong and the
+ server sends an informative error message to the client
+ @note: Other exceptions can be raised.
+ """
+ dst_path = os.path.abspath(dst_path)
+ end_time = time.time() + timeout
+ try:
+ try:
+ self._send_msg(RSS_SET_PATH)
+ self._send_packet(src_pattern)
+ except FileTransferError:
+ self._handle_transfer_error()
+ while True:
+ msg = self._receive_msg()
+ if msg == RSS_CREATE_FILE:
+ # Receive filename and file contents
+ filename = self._receive_packet()
+ if os.path.isdir(dst_path):
+ dst_path = os.path.join(dst_path, filename)
+ self._receive_packet_into_file(
+ dst_path, max(0, end_time - time.time()))
+ dst_path = os.path.dirname(dst_path)
+ elif msg == RSS_CREATE_DIR:
+ # Receive dirname and create the directory
+ dirname = self._receive_packet()
+ if os.path.isdir(dst_path):
+ dst_path = os.path.join(dst_path, dirname)
+ if not os.path.isdir(dst_path):
+ os.mkdir(dst_path)
+ elif msg == RSS_LEAVE_DIR:
+ # Return to parent dir
+ dst_path = os.path.dirname(dst_path)
+ elif msg == RSS_DONE:
+ # Transfer complete
+ break
+ elif msg == RSS_ERROR:
+ # Receive error message and abort
+ errmsg = self._receive_packet()
+ raise FileTransferServerError("Server said: %s" % errmsg)
+ else:
+ # Unexpected msg
+ raise FileTransferProtocolError("Received unexpected msg")
+ except:
+ # In any case, if the transfer failed, close the connection
+ self.close()
+ raise
+
+
+def main():
+ import optparse
+
+ usage = "usage: %prog [options] address port src_pattern dst_path"
+ parser = optparse.OptionParser(usage=usage)
+ parser.add_option("-d", "--download",
+ action="store_true", dest="download",
+ help="download files from server")
+ parser.add_option("-u", "--upload",
+ action="store_true", dest="upload",
+ help="upload files to server")
+ parser.add_option("-t", "--timeout",
+ type="int", dest="timeout", default=3600,
+ help="transfer timeout")
+ options, args = parser.parse_args()
+ if options.download == options.upload:
+ parser.error("you must specify either -d or -u")
+ if len(args) != 4:
+ parser.error("incorrect number of arguments")
+ address, port, src_pattern, dst_path = args
+ port = int(port)
+
+ if options.download:
+ client = FileDownloadClient(address, port)
+ client.download(src_pattern, dst_path, timeout=options.timeout)
+ client.close()
+ elif options.upload:
+ client = FileUploadClient(address, port)
+ client.upload(src_pattern, dst_path, timeout=options.timeout)
+ client.close()
+
+
+if __name__ == "__main__":
+ main()