# Copyright (c) 2011 - 2017, Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``sshtun.py``
`Setup ssh tunnel with local port forwarding`
"""
import select
import socket
import socketserver
from threading import Thread
import time
import paramiko
from . import loggers
def get_local_port():
    """Ask the operating system for a currently unused TCP port.

    A throwaway IPv4 TCP socket is bound to port 0 on all interfaces, which
    makes the kernel choose a free port. The chosen port number is read back
    and the socket is closed again.

    Note:
        The port is released before this function returns, so another process
        may claim it before the caller binds to it.

    Returns:
        int: Port number that was free at the time of the call.

    """
    # The context manager guarantees the probe socket is closed even if
    # bind() or getsockname() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
class ForwardHandlerMixin(socketserver.BaseRequestHandler):
    """Base class of port forwarding handler.

    Each accepted local connection is relayed through an SSH "direct-tcpip"
    channel to a remote endpoint. Subclasses have to provide the class
    attributes ``transport``, ``remote_host`` and ``remote_port``, which are
    read in ``handle``.

    """

    class_logger = loggers.ClassLogger()

    # Seconds without any forwarded data after which the request is dropped.
    idle_timeout = 350
    # Maximum time in seconds a single select() call waits for readable data.
    select_timeout = 60
    # Maximum number of bytes read from either side per recv() call.
    buffer_size = 1024

    def setup(self):
        """Set connection timeout.

        """
        self.request.settimeout(3)

    def handle(self):
        """Processing incoming request.

        Opens a "direct-tcpip" channel for the accepted connection and relays
        data in both directions until one side closes or the connection stays
        idle for ``idle_timeout`` seconds. If the channel cannot be opened, a
        warning is logged and the request is dropped.

        """
        peername = self.request.getpeername()
        try:
            channel = self.transport.open_channel("direct-tcpip",
                                                  (self.remote_host, self.remote_port),
                                                  peername)
        except Exception as err:
            self.class_logger.warning("Failed to process incoming request to {0}:{1}. Error: {2}".
                                      format(self.remote_host, self.remote_port, err))
            return
        if channel is None:
            self.class_logger.warning("Incoming request to {0}:{1} is rejected by the SSH server.".
                                      format(self.remote_host, self.remote_port))
            return

        try:
            ch_peername = channel.getpeername()
            self.class_logger.debug("Processing request: {0}:{1}->{2}:{3}->{4}:{5}".
                                    format(peername[0], peername[1],
                                           ch_peername[0], ch_peername[1],
                                           self.remote_host, self.remote_port))
            self._relay(channel, peername)
        finally:
            # Release both endpoints even when the relay loop raised, so a
            # broken connection does not leak the channel or the socket.
            channel.close()
            self.request.close()
        self.class_logger.debug("Request is processed: {0}:{1}".format(peername[0], peername[1]))

    def _relay(self, channel, peername):
        """Copy data between the client socket and the SSH channel.

        Args:
            channel: Open channel returned by ``transport.open_channel``
            peername(tuple): Address of the connected client, used for logging

        """
        deadline = time.time() + self.idle_timeout
        while True:
            if time.time() > deadline:
                self.class_logger.error("Request timed out: {0}:{1}".format(peername[0], peername[1]))
                return
            readable, _, _ = select.select([self.request, channel], [], [], self.select_timeout)
            if self.request in readable:
                data = self.request.recv(self.buffer_size)
                if not data:
                    # Empty read: the client closed its side.
                    return
                # sendall() keeps writing until the whole buffer is delivered;
                # a plain send() may transmit only a part of it.
                channel.sendall(data)
                deadline = time.time() + self.idle_timeout
            if channel in readable:
                data = channel.recv(self.buffer_size)
                if not data:
                    # Empty read: the channel was closed.
                    return
                self.request.sendall(data)
                deadline = time.time() + self.idle_timeout
class ForwardServer(socketserver.TCPServer):
    """socketserver.TCPServer with the class-level options used by the tunnel.

    """

    # Permit rebinding to an address that was in use shortly before.
    allow_reuse_address = True

    # NOTE(review): this flag is only read by socketserver.ThreadingMixIn; the
    # base class here is plain TCPServer, so it presumably has no effect and
    # requests are handled one at a time - confirm whether that is intended.
    daemon_threads = True
class SSHTunnel(object):
    """Main class for creating ssh tunnel.

    The tunnel consists of a paramiko SSH connection to ``server`` and a local
    TCP server, run in a separate thread, that forwards every accepted
    connection to ``remote`` through that SSH connection.

    """

    class_logger = loggers.ClassLogger()

    def __init__(self, server, user, passwd, remote, local_port=None, local_host="127.0.0.1"):
        """Initialize SSHTunnel class.

        Args:
            server(list | tuple): Server information in format [ip, port]
            user(str): Username
            passwd(str): Password
            remote(list | tuple): Remote server information in format [ip, port]
            local_port(int): Local port assigned for forwarding
            local_host(str): Local IP to listen on, defaults to 127.0.0.1

        """
        self.srv = server
        self.usr = user
        self.passwd = passwd
        self.remote = remote
        # Requested local port; None means "pick a free one" in fwdport().
        self.local = local_port
        # Port actually used by the running forwarding server.
        self.local_port = None
        self.local_host = local_host
        self.server = None
        self.transport = None
        self.ssh_client = None
        self.thr = None

    def fwdport(self):
        """Launch port forwarding server.

        Blocks in ``serve_forever`` until the server is shut down, so it is
        meant to be used as a thread target (see ``start``).

        """
        # Taken from paramiko examples.
        # SocketServer doesn't give Handlers any way to access the outer server normally.
        class FHandler(ForwardHandlerMixin):
            """Get configuration of forwarding server.

            """
            remote_host = self.remote[0]
            remote_port = self.remote[1]
            transport = self.transport
            class_logger = loggers.ClassLogger()

        self.local_port = self.local or get_local_port()
        self.class_logger.debug("Try to setup port forwarding: {0} -> {1}:{2} ...".
                                format(self.local_port, self.remote[0], self.remote[1]))
        # Keep a local reference: stop() resets self.server to None.
        server = ForwardServer((self.local_host, self.local_port), FHandler)
        self.server = server
        server.serve_forever()

    def connect(self):
        """Perform ssh connection.

        Raises:
            Exception: error on connect

        """
        client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification - confirm this is acceptable for the target environment.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.class_logger.info("Establishing ssh connection to host {0}:{1} ...".format(self.srv[0], self.srv[1]))
        try:
            client.connect(self.srv[0], self.srv[1],
                           username=self.usr, password=self.passwd,
                           key_filename=None, look_for_keys=False)
            self.class_logger.info("SSH connection is established.")
        except Exception as err:
            self.class_logger.error("Failed to connect to {0}:{1}. Error: {2}".
                                    format(self.srv[0], self.srv[1], err))
            # The failed client is not stored on self yet, so release it here.
            client.close()
            self.stop()
            raise
        self.ssh_client = client
        self.transport = client.get_transport()
        # Set keepalive to be aware if transport is active.
        self.transport.set_keepalive(15)

    def start(self, timeout=60):
        """Perform ssh connection and start port forwarding server in thread.

        Args:
            timeout(int): Port forwarding configuration timeout

        Raises:
            Exception: forwarding thread died before the server was created,
                or timeout exceeded on start

        """
        self.connect()
        if not (self.ssh_client and self.transport):
            self.ssh_client = None
            self.transport = None
            return

        self.thr = Thread(target=self.fwdport)
        self.thr.start()
        # Add timeout to fwdport function.
        end_time = time.time() + timeout
        while self.server is None:
            if not self.thr.is_alive():
                # fwdport() raised before publishing the server (for example
                # the local port is busy); waiting any longer cannot succeed.
                self.stop()
                raise Exception("SSH tunnel TCP Server failed to start: fwdport thread exited.")
            if time.time() >= end_time:
                raise Exception("SSH tunnel TCP Server isn't started in {0} seconds.".format(timeout))
            time.sleep(0.1)

    def stop(self):
        """Stop port forwarding server and thread.

        Safe to call repeatedly and on a tunnel that was never connected.

        """
        server = self.server
        if server is not None:
            self.class_logger.debug("Stop port forwarding: {0} -> {1}:{2}".
                                    format(self.local_port, self.remote[0], self.remote[1]))
            server.shutdown()
            # shutdown() only ends serve_forever(); the listening socket has
            # to be closed explicitly.
            server.server_close()
            self.local_port = None
            self.server = None
        # ssh_client is None when connect() failed or stop() already ran.
        if self.ssh_client is not None:
            self.ssh_client.close()
        self.transport = None
        self.ssh_client = None
        self.class_logger.info("Stopping fwdport thread ...")
        thread = self.thr
        if thread is None:
            return
        if thread.is_alive():
            thread.join()
        else:
            self.class_logger.info("fwdport thread isn't started. Skipping stop procedures.")
        self.thr = None

    def establish(self):
        """Start sshtun server and wait while connection is established.

        If the forwarding thread is already running and the transport is
        active, nothing is restarted and the current local port is returned.

        Raises:
            Exception: timeout exceeded

        Returns:
            int: local port

        """
        if self.thr is not None and self.thr.is_alive():
            self.class_logger.info("fwdport thread already started. Checking ...")
            if self.check():
                self.class_logger.info("Transport is active. Skip start tunnel procedures.")
                return self._bound_port()
            self.close()
        self.class_logger.info("Starting fwdport thread ...")
        self.start()
        end_time = time.time() + 60
        while self.local_port is None:
            if time.time() > end_time:
                self.stop()
                raise Exception("Timeout exceeded. Local port isn't set and port forwarding isn't established.")
            time.sleep(0.21)
        time.sleep(0.5)
        return self._bound_port()

    def _bound_port(self):
        """Return the port of the forwarding server socket, or ``local_port`` if no server is set."""
        server = self.server
        if server is None:
            return self.local_port
        return server.socket.getsockname()[1]

    def close(self):
        """Close sshtun server.

        """
        self.stop()

    def check(self):
        """Return True if connection is established.

        Returns:
            bool: result of ``transport.is_active()``, or False when there is
                no transport

        """
        if not self.transport:
            return False
        return self.transport.is_active()

    def __del__(self):
        """Try to close connection on object destroy.

        """
        try:
            self.class_logger.debug("SSHTunnel object has to be deleted. Try to stop forwarding if it is active.")
            self.stop()
        except Exception:
            # Best effort only: the object may be partially initialised when
            # __init__ did not complete, so attributes can be missing here.
            pass