Advertisement
CSenshi

CN - HW_Firewall (Stateful)

Jan 2nd, 2020
383
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 14.47 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. from main import PKT_DIR_INCOMING, PKT_DIR_OUTGOING
  4.  
  5. import socket
  6. import struct
  7.  
  8. class Firewall:
  9.  
  10.     def __init__(self, config, iface_int, iface_ext):
  11.         self.iface_int = iface_int
  12.         self.iface_ext = iface_ext
  13.         rules_filename = config['rule']
  14.  
  15.         # rules[i] = [<verdict>, <protocol>, <external IP address>, <external port>] or
  16.         #            [<verdict>, 'dns', <domain name>]
  17.         rules = []
  18.         # Load the firewall rules (from rules_filename)
  19.         with open(rules_filename) as rules_file:
  20.             for rule_line in rules_file:
  21.                 if rule_line[0] not in ['\n', '%']:
  22.                     rules.append(rule_line.lower().split())
  23.  
  24.         # Save list of rules
  25.         self.rules = rules
  26.  
  27.         # geoipdb[<2-character country code>] = [<start IP address>, <end IP address>]
  28.         geoipdb = {}
  29.         # Load the GeoIP DB ('geoipdb.txt')
  30.         with open('geoipdb.txt') as geoipdb_file:
  31.             for geoip in geoipdb_file:
  32.                 geoip = geoip.split()
  33.                 country_code = geoip[2].upper()
  34.                 if country_code not in geoipdb:
  35.                     geoipdb[country_code] = []
  36.                 geoipdb[country_code].append(tuple(geoip[0:2]))
  37.  
  38.         # Save GeoIP DB
  39.         self.geoipdb = geoipdb
  40.  
  41.     # @packet_dir: either PKT_DIR_INCOMING or PKT_DIR_OUTGOING
  42.     # @packet: the actual data of the IPv4 packet (including IP header)
  43.     def handle_packet(self, packet_dir, packet):
  44.        
  45.         protocol = get_protocol(packet)
  46.         if not protocol:
  47.             self.pass_packet(packet_dir, packet)
  48.             return
  49.  
  50.         external_ip = get_external_ip(packet_dir, packet)
  51.         external_port = get_external_port(packet_dir, packet, protocol)
  52.  
  53.         for rule in self.rules:
  54.             if rule[1] == 'dns':
  55.                 qname = matches_dns_rules(packet_dir, packet, protocol, external_port)
  56.                 if qname and is_subdomain(qname, rule[2]):
  57.  
  58.                     if rule[0] == 'pass':
  59.                         self.pass_packet(packet_dir, packet)
  60.  
  61.                     elif rule[0] == 'deny':
  62.                         response = get_dns_response(packet)
  63.                         if response:
  64.                             self.reflect_packet(packet_dir, response)
  65.  
  66.                     return
  67.  
  68.             elif rule[1] == 'http':
  69.                 try_to_log(rule[2], protocol, external_ip, external_port, packet_dir, packet)
  70.  
  71.             else:
  72.                 if self.matches(rule, protocol, external_ip, external_port):
  73.  
  74.                     if rule[0] == 'pass':
  75.                         self.pass_packet(packet_dir, packet)
  76.  
  77.                     elif rule[0] == 'deny' and rule[1] == 'tcp':
  78.                         response = get_rst_packet(packet)
  79.                         self.reflect_packet(packet_dir, response)
  80.  
  81.                     return
  82.  
  83.         # If none of the rules match, just pass the packet
  84.         self.pass_packet(packet_dir, packet)
  85.  
  86.     # Returns True if 'rule' matched with passed parameters
  87.     def matches(self, rule, protocol, external_ip, external_port):
  88.         return protocol == rule[1] and \
  89.                 self.matches_external_ip(rule[2], external_ip) and \
  90.                  self.matches_external_port(rule[3], external_port)
  91.  
  92.     # Returns True if 'ext_ip' matches to 'rule_ext_ip'
  93.     def matches_external_ip(self, rule_ext_ip, ext_ip):
  94.  
  95.         if rule_ext_ip in ['any', ext_ip]:
  96.             return True
  97.  
  98.         if len(rule_ext_ip) == 2: # is a 2-byte country code
  99.             ip_ranges = self.geoipdb.get(rule_ext_ip.upper(), None)
  100.             if ip_ranges:
  101.  
  102.                 def to_int(x):
  103.                     return struct.unpack('!I', socket.inet_aton(x))[0]
  104.  
  105.                 ip = to_int(ext_ip)
  106.                 for ip_range in ip_ranges:
  107.                     start, finish = to_int(ip_range[0]), to_int(ip_range[1])
  108.                     if (ip > start - 1) and (ip < finish + 1):
  109.                         return True
  110.  
  111.         elif '/' in rule_ext_ip:
  112.             return is_in_subnet(ext_ip, rule_ext_ip)
  113.  
  114.         return False
  115.  
  116.     # Returns True if 'ext_port' matches to 'rule_ext_port'
  117.     def matches_external_port(self, rule_ext_port, ext_port):
  118.  
  119.         if rule_ext_port in ['any', str(ext_port)]:
  120.             return True
  121.  
  122.         if '-' in rule_ext_port:
  123.             rule_ext_port = list(map(lambda x: int(x), rule_ext_port.split('-')))
  124.             return (ext_port > rule_ext_port[0] - 1) and (ext_port < rule_ext_port[1] + 1)
  125.  
  126.         return False
  127.  
  128.     # Sends packet in specific direction
  129.     def pass_packet(self, packet_dir, packet):
  130.         if packet_dir == PKT_DIR_INCOMING:
  131.             self.iface_int.send_ip_packet(packet)
  132.         elif packet_dir == PKT_DIR_OUTGOING:
  133.             self.iface_ext.send_ip_packet(packet)
  134.  
  135.     # Reflects packet at same interface
  136.     def reflect_packet(self, packet_dir, packet):
  137.         if packet_dir == PKT_DIR_INCOMING:
  138.             self.iface_ext.send_ip_packet(packet)
  139.         elif packet_dir == PKT_DIR_OUTGOING:
  140.             self.iface_int.send_ip_packet(packet)
  141.  
  142. # Returns protocol as string
  143. # or None if such does not specified in protocol_switch
  144. def get_protocol(packet):
  145.     protocol_switch = { 1 : 'icmp', 6 : 'tcp', 17 : 'udp' }
  146.     protocol = struct.unpack('!B', packet[9])[0]
  147.     return protocol_switch.get(protocol, None)
  148.  
  149. # Returns external ip (not local) of packet
  150. def get_external_ip(packet_dir, packet):
  151.     if packet_dir == PKT_DIR_INCOMING:
  152.         external_ip = packet[12:16] # source
  153.     elif packet_dir == PKT_DIR_OUTGOING:
  154.         external_ip = packet[16:20] # destination
  155.     # Convert Internet host address to an Internet dot address and return
  156.     return socket.inet_ntoa(external_ip)
  157.  
  158. # Returns port of external host
  159. def get_external_port(packet_dir, packet, protocol):
  160.     i = get_ip_header_len(packet)
  161.     if protocol == 'icmp':
  162.         port = struct.unpack('!B', packet[i])[0]
  163.     elif protocol in ['tcp', 'udp']:
  164.         if packet_dir == PKT_DIR_OUTGOING: i += 2
  165.         port = struct.unpack('!H', packet[i:i + 2])[0]
  166.     return port
  167.  
  168. # Returns header length specified in packet
  169. def get_ip_header_len(packet):
  170.     return (struct.unpack('!B', packet[0])[0] & 15) * 4
  171.  
  172. # Returns True if ip is in subnet
  173. def is_in_subnet(ip, subnet):
  174.     subnet = subnet.split('/')
  175.     subnet[1] = int(subnet[1])
  176.     start = struct.unpack('!I', socket.inet_aton(subnet[0]))[0] & \
  177.             int(subnet[1] * '1' + (32 - subnet[1]) * '0', 2)
  178.     finish =  start + int((32 - subnet[1]) * '1', 2)
  179.     ip = struct.unpack('!I', socket.inet_aton(ip))[0]
  180.     return (ip > start - 1) and (ip < finish + 1)
  181.  
  182. # Returns True if passed parameters matches with dns_rules
  183. def matches_dns_rules(packet_dir, packet, protocol, port):
  184.  
  185.     if packet_dir != PKT_DIR_OUTGOING or \
  186.         protocol != 'udp' or port != 53:
  187.         return False
  188.  
  189.     # 12 = length of udp header (8) + length before QDCOUNT entry in DNS Message (4)
  190.     i = get_ip_header_len(packet)
  191.     i += 12
  192.     try: qdcount = struct.unpack('!H', packet[i:i + 2])[0]
  193.     except: return False
  194.  
  195.     if qdcount != 1:
  196.         return False
  197.  
  198.     # length of DNS Message Header (12) -
  199.     # length from transport layer before QDCOUNT entry in DNS Message (4)
  200.     i += 8
  201.     qname = ''
  202.     while True:
  203.         try:
  204.             if packet[i] == '\x00':
  205.                 break
  206.         except:
  207.             return False
  208.         offset = struct.unpack('!B', packet[i])[0]
  209.         qname += packet[i + 1:i + 1 + offset] +  '.'
  210.         i += offset + 1
  211.     qname = qname[:-1]
  212.  
  213.     # Jump to QTYPE Section
  214.     i += 1
  215.     try:
  216.         qtype = struct.unpack('!H', packet[i:i + 2])[0]
  217.     except:
  218.         return False
  219.  
  220.     if qtype not in [1, 28]:
  221.         return False
  222.  
  223.     # Jump to QCLASS Section
  224.     i += 2
  225.     try:
  226.         qclass = struct.unpack('!H', packet[i:i + 2])[0]
  227.     except:
  228.         return False
  229.  
  230.     if qclass != 1:
  231.         return False
  232.  
  233.     return qname
  234.  
  235. # Returns True if child is subdomain of parent
  236. def is_subdomain(child, parent):
  237.     if parent == '' and child != '': # Edge case
  238.         return False
  239.     i = get_common_suffix_len(child, parent)
  240.     if i != 0: parent = parent[:-i]
  241.     return parent in ['', '*']
  242.  
  243. # Returns max length of common suffix
  244. def get_common_suffix_len(str1, str2):
  245.     cmp_fn = lambda x: len(x)
  246.     s1 = min([str1, str2], key = cmp_fn)
  247.     s2 = max([str2, str1], key = cmp_fn)
  248.     i = -1
  249.     while -i != len(s1) + 1 and \
  250.         s1[i] == s2[i]:
  251.         i -= 1
  252.     return -(i + 1)
  253.  
  254. # Returns dns response for denied dns request
  255. def get_dns_response(packet):
  256.  
  257.     # De-encapsulate dns message
  258.     ip_header_len = get_ip_header_len(packet)
  259.     dns_response = packet[ip_header_len + 8:]
  260.  
  261.     i = 12
  262.     while dns_response[i] != '\x00':
  263.         i += struct.unpack('!B', dns_response[i])[0] + 1
  264.     i += 1
  265.  
  266.     qtype = struct.unpack('!H', dns_response[i:i + 2])[0]
  267.     if qtype == 28:
  268.         return None
  269.  
  270.     second = struct.pack('!B', (struct.unpack('!B', dns_response[2])[0] & 127) + 128)
  271.     dns_response = dns_response[0:2] + second + dns_response[3] + \
  272.         '\x00\x01\x00\x01\x00\x00\x00\x00' + dns_response[12: i + 4] + \
  273.         dns_response[12: i] + \
  274.         '\x00\x01\x00\x01\x00\x00\x00\x01\x00\x04\xa9\xe5\x31\x82'
  275.  
  276.     # Construct udp header
  277.     old_ip_header = packet[:ip_header_len]
  278.     old_udp_header = packet[ip_header_len: ip_header_len + 8]
  279.     udp_header = construct_udp_header(old_ip_header, old_udp_header, dns_response)
  280.  
  281.     # Construct segment
  282.     segment = udp_header + dns_response
  283.  
  284.     # Construct ip header
  285.     ip_header = construct_ip_header(old_ip_header, len(segment))
  286.  
  287.     # Construct new packet
  288.     new_packet = ip_header + segment
  289.     return new_packet
  290.  
  291. # Costructs udp header depending on old one
  292. def construct_udp_header(old_ip_header, old_udp_header, message):
  293.  
  294.     # 2) De-encapsulate old udp header
  295.     src_port = old_udp_header[2:4]
  296.     dst_port = old_udp_header[0:2]
  297.     length = 8 + len(message)
  298.  
  299.     # Construct new udp header
  300.     udp_header = src_port + dst_port + struct.pack('!H', length) + '\x00\x00'
  301.     pseudo_header = get_pseudo_header(old_ip_header, length)
  302.     checksum = get_checksum(pseudo_header + udp_header + message)
  303.  
  304.     udp_header = udp_header[:6] + checksum
  305.     return udp_header
  306.  
  307. # Returns pseudo header for calculating checksum
  308. def get_pseudo_header(old_ip_header, segment_len):
  309.  
  310.     src_ip = old_ip_header[16:20]
  311.     dst_ip = old_ip_header[12:16]
  312.     protocol = old_ip_header[9]
  313.     segment_len = struct.pack('!H', segment_len)
  314.  
  315.     pseudo_header = src_ip + dst_ip + '\x00' + protocol + segment_len
  316.     return pseudo_header
  317.  
  318. # Costructs ip header depending on old one
  319. def construct_ip_header(old_ip_header, segment_len):
  320.  
  321.     src_ip = old_ip_header[16:20]
  322.     dst_ip = old_ip_header[12:16]
  323.     protocol = old_ip_header[9]
  324.  
  325.     # Construct new IP header
  326.     length = 20 + segment_len
  327.     length = struct.pack('!H', length)
  328.  
  329.     ip_header = '\x45\x00' + length + \
  330.         '\x11\x11\x00\x00\x01' + protocol + '\x00\x00' + src_ip + dst_ip
  331.  
  332.     checksum = get_checksum(ip_header)
  333.  
  334.     ip_header = ip_header[:10] + checksum + ip_header[12:]
  335.     return ip_header
  336.  
  337. # Returns appropriate tcp RST packet
  338. def get_rst_packet(packet):
  339.  
  340.     old_ip_header_len = get_ip_header_len(packet)
  341.  
  342.     old_ip_header = packet[0: old_ip_header_len]
  343.     old_tcp_header = packet[old_ip_header_len:]
  344.  
  345.     tcp_header = get_rst_tcp_header(old_ip_header, old_tcp_header)
  346.     ip_header = construct_ip_header(old_ip_header, len(tcp_header))
  347.  
  348.     new_packet = ip_header + tcp_header
  349.     return new_packet
  350.  
  351. # Returns RST tcp header
  352. def get_rst_tcp_header(old_ip_header, old_tcp_header):
  353.  
  354.     # De-encapsulate old_tcp_header
  355.     src_port = old_tcp_header[2:4]
  356.     dst_port = old_tcp_header[0:2]
  357.     seq_num = old_tcp_header[8:12]
  358.     ack = old_tcp_header[4:8]
  359.  
  360.     tcp_header = src_port + dst_port + seq_num + ack + \
  361.         '\x50\x14\x00\x00\x00\x00\x00\x00'
  362.  
  363.     # Construct pseudo header
  364.     pseudo_header = get_pseudo_header(old_ip_header, len(tcp_header))
  365.     checksum = get_checksum(pseudo_header + tcp_header)
  366.  
  367.     tcp_header = tcp_header[:16] + checksum + tcp_header[18:]
  368.     return tcp_header
  369.  
  370. # Logs if needed
  371. def try_to_log(rule_host_name, protocol, external_ip, external_port, packet_dir, packet):
  372.  
  373.     if packet_dir != PKT_DIR_OUTGOING or \
  374.         protocol != 'tcp' or external_port != 80:
  375.         return
  376.  
  377.     start = packet.find('Host')
  378.  
  379.     if start == -1:
  380.         host_name = external_ip
  381.         matches = (rule_host_name == host_name)
  382.  
  383.     else:
  384.         finish = packet.find('\r\n', start)
  385.         host_name = packet[start + 6: finish]
  386.         matches = is_subdomain(host_name, rule_host_name)
  387.  
  388.     if not matches:
  389.         return
  390.  
  391.     start, method = packet.find('GET'), 'GET'
  392.     if start == -1:
  393.         for m in ['POST', 'HEAD']:
  394.             start = packet.find(m)
  395.             if start != -1:
  396.                 method = m
  397.                 break
  398.  
  399.     if start == -1: path = '/'
  400.     else:
  401.         start += len(method) + 1
  402.         finish = packet.find(' ', start)
  403.         path = packet[start: finish]
  404.  
  405.     version = 'HTTP/1.1'
  406.     status_code = '200'
  407.     object_size = str(len(packet) - 40)
  408.  
  409.     data = [host_name, method, path, version, status_code, object_size]
  410.     to_log = ' '.join(data) + '\r\n'
  411.  
  412.     f = open('http.log', 'a')
  413.     f.write(to_log)
  414.     f.flush()
  415.     f.close()
  416.  
  417. # Returns checksum of passed data
  418. def get_checksum(data):
  419.  
  420.     checksum = '0'
  421.  
  422.     while len(data) > 1:
  423.         x = data[:2]
  424.  
  425.         x0 = bin(ord(x[0]))[2:]
  426.         x1 = bin(ord(x[1]))[2:]
  427.  
  428.         if len(x1) < 8:
  429.             x1 = (8 - len(x1)) * '0' + x1
  430.  
  431.         x = x0 + x1
  432.  
  433.         checksum = add(checksum, x)
  434.         data = data[2:]
  435.  
  436.     if len(data) == 1:
  437.         x = bin(ord(data[0]))[2:] + 8 * '0'
  438.         checksum = add(checksum, x)
  439.  
  440.     while len(checksum) > 16:
  441.         left = checksum[:len(checksum) - 16]
  442.         right = checksum[len(checksum) - 16:]
  443.         checksum = add(left, right)
  444.  
  445.     if len(checksum) < 16:
  446.         checksum = (16 - len(checksum)) * '0' + checksum
  447.  
  448.     reverse = ''
  449.     for i in range(16):
  450.         if checksum[i] == '0':
  451.             reverse += '1'
  452.         else: # if checksum[i] == 1
  453.             reverse += '0'
  454.  
  455.     checksum = reverse
  456.     checksum = int(checksum, 2)
  457.     checksum = struct.pack('!H', checksum)
  458.  
  459.     return checksum
  460.  
  461. # Adds to binary strings
  462. def add(a, b):
  463.     return bin(int(a, 2) + int(b, 2))[2:]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement