Advertisement
CSenshi

CN - HW_Firewall (Stateless)

Jan 2nd, 2020
394
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.68 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. from main import PKT_DIR_INCOMING, PKT_DIR_OUTGOING
  4.  
  5. import socket
  6. import struct
  7.  
  8. class Firewall:
  9.  
  10.     def __init__(self, config, iface_int, iface_ext):
  11.         self.iface_int = iface_int
  12.         self.iface_ext = iface_ext
  13.         rules_filename = config['rule']
  14.  
  15.         # rules[i] = [<verdict>, <protocol>, <external IP address>, <external port>] or
  16.         #            [<verdict>, 'dns', <domain name>]
  17.         rules = []
  18.         # Load the firewall rules (from rules_filename)
  19.         with open(rules_filename) as rules_file:
  20.             for rule_line in rules_file:
  21.                 if rule_line[0] not in ['\n', '%']:
  22.                     rules.append(rule_line.lower().split())
  23.  
  24.         # Save list of rules
  25.         self.rules = rules
  26.  
  27.         # geoipdb[<2-character country code>] = [<start IP address>, <end IP address>]
  28.         geoipdb = {}
  29.         # Load the GeoIP DB ('geoipdb.txt')
  30.         with open('geoipdb.txt') as geoipdb_file:
  31.             for geoip in geoipdb_file:
  32.                 geoip = geoip.split()
  33.                 country_code = geoip[2].upper()
  34.                 if country_code not in geoipdb:
  35.                     geoipdb[country_code] = []
  36.                 geoipdb[country_code].append(tuple(geoip[0:2]))
  37.  
  38.         # Save GeoIP DB
  39.         self.geoipdb = geoipdb
  40.  
  41.     # @packet_dir: either PKT_DIR_INCOMING or PKT_DIR_OUTGOING
  42.     # @packet: the actual data of the IPv4 packet (including IP header)
  43.     def handle_packet(self, packet_dir, packet):
  44.  
  45.         protocol = get_protocol(packet)
  46.         if not protocol:
  47.             self.pass_packet(packet_dir, packet)
  48.             return
  49.  
  50.         external_ip = get_external_ip(packet_dir, packet)
  51.         external_port = get_external_port(packet_dir, packet, protocol)
  52.  
  53.         for rule in self.rules:
  54.             if rule[1] == 'dns':
  55.                 qname = matches_dns_rules(packet_dir, packet, protocol, external_port)
  56.                 if qname and is_subdomain(qname, rule[2]):
  57.                     if rule[0] == 'pass':
  58.                         self.pass_packet(packet_dir, packet)
  59.                     return
  60.             else:
  61.                 if self.matches(rule, protocol, external_ip, external_port):
  62.                     if rule[0] == 'pass':
  63.                         self.pass_packet(packet_dir, packet)
  64.                     return
  65.  
  66.         # If none of the rules match, just pass the packet
  67.         self.pass_packet(packet_dir, packet)
  68.  
  69.     # Returns True if 'rule' matched with passed parameters
  70.     def matches(self, rule, protocol, external_ip, external_port):
  71.         return protocol == rule[1] and \
  72.                 self.matches_external_ip(rule[2], external_ip) and \
  73.                  self.matches_external_port(rule[3], external_port)
  74.  
  75.     # Returns True if 'ext_ip' matches to 'rule_ext_ip'
  76.     def matches_external_ip(self, rule_ext_ip, ext_ip):
  77.  
  78.         if rule_ext_ip in ['any', ext_ip]:
  79.             return True
  80.  
  81.         if len(rule_ext_ip) == 2: # is a 2-byte country code
  82.             ip_ranges = self.geoipdb.get(rule_ext_ip.upper(), None)
  83.             if ip_ranges:
  84.  
  85.                 def to_int(x):
  86.                     return struct.unpack('!I', socket.inet_aton(x))[0]
  87.  
  88.                 ip = to_int(ext_ip)
  89.                 for ip_range in ip_ranges:
  90.                     start, finish = to_int(ip_range[0]), to_int(ip_range[1])
  91.                     if (ip > start - 1) and (ip < finish + 1):
  92.                         return True
  93.  
  94.         elif '/' in rule_ext_ip:
  95.             return is_in_subnet(ext_ip, rule_ext_ip)
  96.  
  97.         return False
  98.  
  99.     # Returns True if 'ext_port' matches to 'rule_ext_port'
  100.     def matches_external_port(self, rule_ext_port, ext_port):
  101.  
  102.         if rule_ext_port in ['any', str(ext_port)]:
  103.             return True
  104.  
  105.         if '-' in rule_ext_port:
  106.             rule_ext_port = list(map(lambda x: int(x), rule_ext_port.split('-')))
  107.             return (ext_port > rule_ext_port[0] - 1) and (ext_port < rule_ext_port[1] + 1)
  108.  
  109.         return False
  110.  
  111.     # Sends packet in specific direction
  112.     def pass_packet(self, packet_dir, packet):
  113.         if packet_dir == PKT_DIR_INCOMING:
  114.             self.iface_int.send_ip_packet(packet)
  115.         elif packet_dir == PKT_DIR_OUTGOING:
  116.             self.iface_ext.send_ip_packet(packet)
  117.  
  118. # Returns protocol as string
  119. # or None if such does not specified in protocol_switch
  120. def get_protocol(packet):
  121.     protocol_switch = { 1 : 'icmp', 6 : 'tcp', 17 : 'udp' }
  122.     protocol = struct.unpack('!B', packet[9])[0]
  123.     return protocol_switch.get(protocol, None)
  124.  
  125. # Returns external ip (not local) of packet
  126. def get_external_ip(packet_dir, packet):
  127.     if packet_dir == PKT_DIR_INCOMING:
  128.         external_ip = packet[12:16] # source
  129.     elif packet_dir == PKT_DIR_OUTGOING:
  130.         external_ip = packet[16:20] # destination
  131.     # Convert Internet host address to an Internet dot address and return
  132.     return socket.inet_ntoa(external_ip)
  133.  
  134. # Returns port of external host
  135. def get_external_port(packet_dir, packet, protocol):
  136.     i = get_ip_header_len(packet)
  137.     if protocol == 'icmp':
  138.         port = struct.unpack('!B', packet[i])[0]
  139.     elif protocol in ['tcp', 'udp']:
  140.         if packet_dir == PKT_DIR_OUTGOING: i += 2
  141.         port = struct.unpack('!H', packet[i:i + 2])[0]
  142.     return port
  143.  
  144. # Returns header length specified in packet
  145. def get_ip_header_len(packet):
  146.     return (struct.unpack('!B', packet[0])[0] & 15) * 4
  147.  
  148. # Returns True if ip is in subnet
  149. def is_in_subnet(ip, subnet):
  150.     subnet = subnet.split('/')
  151.     subnet[1] = int(subnet[1])
  152.     start = struct.unpack('!I', socket.inet_aton(subnet[0]))[0] & \
  153.             int(subnet[1] * '1' + (32 - subnet[1]) * '0', 2)
  154.     finish =  start + int((32 - subnet[1]) * '1', 2)
  155.     ip = struct.unpack('!I', socket.inet_aton(ip))[0]
  156.     return (ip > start - 1) and (ip < finish + 1)
  157.  
  158. # Returns True if passed parameters matches with dns_rules
  159. def matches_dns_rules(packet_dir, packet, protocol, port):
  160.  
  161.     if packet_dir != PKT_DIR_OUTGOING or \
  162.         protocol != 'udp' or port != 53:
  163.         return False
  164.  
  165.     # 12 = length of udp header (8) + length before QDCOUNT entry in DNS Message (4)
  166.     i = get_ip_header_len(packet)
  167.     i += 12
  168.     try: qdcount = struct.unpack('!H', packet[i:i + 2])[0]
  169.     except: return False
  170.  
  171.     if qdcount != 1:
  172.         return False
  173.  
  174.     # length of DNS Message Header (12) -
  175.     # length from transport layer before QDCOUNT entry in DNS Message (4)
  176.     i += 8
  177.     qname = ''
  178.     while True:
  179.         try:
  180.             if packet[i] == '\x00':
  181.                 break
  182.         except:
  183.             return False
  184.         offset = struct.unpack('!B', packet[i])[0]
  185.         qname += packet[i + 1:i + 1 + offset] +  '.'
  186.         i += offset + 1
  187.     qname = qname[:-1]
  188.  
  189.     # Jump to QTYPE Section
  190.     i += 1
  191.     try:
  192.         qtype = struct.unpack('!H', packet[i:i + 2])[0]
  193.     except:
  194.         return False
  195.  
  196.     if qtype not in [1, 28]:
  197.         return False
  198.  
  199.     # Jump to QCLASS Section
  200.     i += 2
  201.     try:
  202.         qclass = struct.unpack('!H', packet[i:i + 2])[0]
  203.     except:
  204.         return False
  205.  
  206.     if qclass != 1:
  207.         return False
  208.  
  209.     return qname
  210.  
  211. # Returns True if child is subdomain of parent
  212. def is_subdomain(child, parent):
  213.     if parent == '' and child != '': # Edge case
  214.         return False
  215.     i = get_common_suffix_len(child, parent)
  216.     if i != 0: parent = parent[:-i]
  217.     return parent in ['', '*']
  218.  
  219. # Returns max length of common suffix
  220. def get_common_suffix_len(str1, str2):
  221.     cmp_fn = lambda x: len(x)
  222.     s1 = min([str1, str2], key = cmp_fn)
  223.     s2 = max([str2, str1], key = cmp_fn)
  224.     i = -1
  225.     while -i != len(s1) + 1 and \
  226.         s1[i] == s2[i]:
  227.         i -= 1
  228.     return -(i + 1)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement