Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python
- from main import PKT_DIR_INCOMING, PKT_DIR_OUTGOING
- import socket
- import struct
class Firewall:
    """Packet filter driven by a rules file and a GeoIP country database.

    Each packet is checked against the rules in file order. The first
    matching rule decides whether the packet is forwarded ('pass') or
    silently discarded (any other verdict, e.g. 'drop'). Packets that match
    no rule, and packets whose protocol is not ICMP/TCP/UDP, are forwarded.
    """

    def __init__(self, config, iface_int, iface_ext):
        """Load the rules file named by config['rule'] and 'geoipdb.txt'.

        @config: mapping whose 'rule' entry is the path of the rules file
        @iface_int: interface used to deliver incoming packets
        @iface_ext: interface used to send outgoing packets
        'geoipdb.txt' is opened relative to the current working directory.
        """
        self.iface_int = iface_int
        self.iface_ext = iface_ext

        # rules[i] = [<verdict>, <protocol>, <external IP address>, <external port>]
        #         or [<verdict>, 'dns', <domain name>]      (all fields lower-cased)
        rules = []
        with open(config['rule']) as rules_file:
            for rule_line in rules_file:
                fields = rule_line.lower().split()
                # Skip blank or whitespace-only lines (e.g. '\r\n') and '%'
                # comments; either would otherwise become a malformed rule.
                if not fields or fields[0].startswith('%'):
                    continue
                rules.append(fields)
        self.rules = rules

        # geoipdb[<2-letter country code>] = [(<start IP>, <end IP>), ...]
        # Both range ends are stored as integers so lookups need no parsing.
        geoipdb = {}
        with open('geoipdb.txt') as geoipdb_file:
            for line in geoipdb_file:
                fields = line.split()
                if len(fields) < 3:
                    continue  # blank or malformed line
                start, end, country_code = fields[:3]
                geoipdb.setdefault(country_code.upper(), []).append(
                    (self._ip_to_int(start), self._ip_to_int(end)))
        self.geoipdb = geoipdb

    def handle_packet(self, packet_dir, packet):
        """Forward or drop one IPv4 packet according to the loaded rules.

        @packet_dir: either PKT_DIR_INCOMING or PKT_DIR_OUTGOING
        @packet: the raw IPv4 packet, including the IP header
        """
        protocol = get_protocol(packet)
        if not protocol:
            # Protocols other than ICMP/TCP/UDP are not filtered.
            self.pass_packet(packet_dir, packet)
            return
        # NOTE(review): the header helpers may raise on truncated packets;
        # confirm the caller handles that.
        external_ip = get_external_ip(packet_dir, packet)
        external_port = get_external_port(packet_dir, packet, protocol)
        # The DNS query name does not depend on the rule, so parse it once.
        # Rules were lower-cased on load and DNS names are case-insensitive,
        # so normalise the query name the same way before comparing.
        qname = matches_dns_rules(packet_dir, packet, protocol, external_port)
        qname = qname.lower() if qname else None
        # NOTE(review): first matching rule wins, following the loop order; if
        # the rules spec says the last match wins, iterate reversed(self.rules).
        for rule in self.rules:
            if rule[1] == 'dns':
                matched = qname is not None and is_subdomain(qname, rule[2])
            else:
                matched = self.matches(rule, protocol, external_ip, external_port)
            if matched:
                if rule[0] == 'pass':
                    self.pass_packet(packet_dir, packet)
                # Any other verdict drops the packet: stop without forwarding.
                return
        # No rule matched: pass by default.
        self.pass_packet(packet_dir, packet)

    def matches(self, rule, protocol, external_ip, external_port):
        """Return True if a non-DNS `rule` matches the given packet fields.

        @rule: [<verdict>, <protocol>, <external IP rule>, <external port rule>]
        """
        return (protocol == rule[1]
                and self.matches_external_ip(rule[2], external_ip)
                and self.matches_external_port(rule[3], external_port))

    def matches_external_ip(self, rule_ext_ip, ext_ip):
        """Return True if dotted-quad `ext_ip` matches the rule's IP field.

        The rule field may be 'any', an exact address, a 2-letter country
        code (looked up in self.geoipdb), or a CIDR block 'a.b.c.d/n'.
        """
        if rule_ext_ip in ('any', ext_ip):
            return True
        if len(rule_ext_ip) == 2:  # 2-letter country code
            ip_ranges = self.geoipdb.get(rule_ext_ip.upper())
            if not ip_ranges:
                return False
            ip = self._ip_to_int(ext_ip)
            return any(start <= ip <= end for start, end in ip_ranges)
        if '/' in rule_ext_ip:
            return is_in_subnet(ext_ip, rule_ext_ip)
        return False

    def matches_external_port(self, rule_ext_port, ext_port):
        """Return True if integer `ext_port` matches the rule's port field.

        The rule field may be 'any', a single port, or an inclusive range
        'low-high'.
        """
        if rule_ext_port in ('any', str(ext_port)):
            return True
        if '-' in rule_ext_port:
            bounds = [int(part) for part in rule_ext_port.split('-')]
            return bounds[0] <= ext_port <= bounds[1]
        return False

    def pass_packet(self, packet_dir, packet):
        """Forward `packet` in its direction of travel.

        Incoming packets go out on the internal interface, outgoing packets
        on the external one; any other `packet_dir` is ignored.
        """
        if packet_dir == PKT_DIR_INCOMING:
            self.iface_int.send_ip_packet(packet)
        elif packet_dir == PKT_DIR_OUTGOING:
            self.iface_ext.send_ip_packet(packet)

    @staticmethod
    def _ip_to_int(dotted_ip):
        """Return a dotted-quad IPv4 address as an unsigned 32-bit integer."""
        return struct.unpack('!I', socket.inet_aton(dotted_ip))[0]
def get_protocol(packet):
    """Return the packet's protocol as 'icmp', 'tcp' or 'udp'.

    Reads the protocol field (byte 9) of the IPv4 header. Returns None for
    any protocol number other than 1 (ICMP), 6 (TCP) or 17 (UDP).
    Accepts Python 2 `str` as well as Python 3 `bytes`/`bytearray` packets.
    Raises struct.error if the packet is shorter than 10 bytes.
    """
    protocol_names = {1: 'icmp', 6: 'tcp', 17: 'udp'}
    # unpack_from reads straight from the buffer, so this works whether
    # indexing a packet yields a 1-char str (Python 2) or an int (Python 3).
    protocol = struct.unpack_from('!B', packet, 9)[0]
    return protocol_names.get(protocol)
def get_external_ip(packet_dir, packet):
    """Return the dotted-quad address of the packet's non-local endpoint.

    For incoming packets this is the IPv4 source address (header bytes
    12-15); for outgoing packets, the destination address (bytes 16-19).
    """
    if packet_dir == PKT_DIR_INCOMING:
        offset = 12  # source address field
    elif packet_dir == PKT_DIR_OUTGOING:
        offset = 16  # destination address field
    # inet_ntoa turns the 4 raw address bytes into 'a.b.c.d' notation.
    return socket.inet_ntoa(packet[offset:offset + 4])
def get_external_port(packet_dir, packet, protocol):
    """Return the external host's port, or the ICMP type for ICMP packets.

    For TCP/UDP, the transport header's first 2 bytes (source port) are used
    for incoming packets and the next 2 bytes (destination port) for outgoing
    packets. For ICMP, the first byte of the ICMP header (its type) is
    returned regardless of direction. Returns None for any other protocol.
    Raises struct.error if the packet is too short.
    """
    offset = get_ip_header_len(packet)
    if protocol == 'icmp':
        # unpack_from works on Python 2 str and Python 3 bytes alike.
        return struct.unpack_from('!B', packet, offset)[0]
    if protocol in ('tcp', 'udp'):
        if packet_dir == PKT_DIR_OUTGOING:
            offset += 2  # skip the source port to reach the destination port
        return struct.unpack_from('!H', packet, offset)[0]
    return None
def get_ip_header_len(packet):
    """Return the IPv4 header length in bytes.

    The low nibble of the first header byte (IHL) counts 32-bit words.
    Accepts Python 2 `str` as well as Python 3 `bytes`/`bytearray` packets.
    """
    return (struct.unpack_from('!B', packet, 0)[0] & 0x0F) * 4
def is_in_subnet(ip, subnet):
    """Return True if dotted-quad `ip` lies inside the CIDR block `subnet`.

    `subnet` has the form '<address>/<prefix length>', e.g. '10.0.0.0/8'.
    Host bits set in <address> are ignored.
    Raises ValueError if the prefix length is outside 0..32.
    """
    network, prefix = subnet.split('/')
    prefix = int(prefix)
    if not 0 <= prefix <= 32:
        raise ValueError('invalid prefix length in subnet %r' % subnet)
    # Build the netmask arithmetically so that /0 and /32 both work.
    mask = (0xFFFFFFFF << (32 - prefix)) & 0xFFFFFFFF
    ip_value = struct.unpack('!I', socket.inet_aton(ip))[0]
    net_value = struct.unpack('!I', socket.inet_aton(network))[0]
    return (ip_value & mask) == (net_value & mask)
def matches_dns_rules(packet_dir, packet, protocol, port):
    """Extract the queried domain name from an outgoing DNS query packet.

    A packet qualifies only if it is an outgoing UDP datagram to port 53
    whose DNS message carries exactly one question with QTYPE A (1) or
    AAAA (28) and QCLASS IN (1).

    @packet_dir: either PKT_DIR_INCOMING or PKT_DIR_OUTGOING
    @packet: raw IPv4 packet (str on Python 2, bytes/bytearray on Python 3)
    @protocol: protocol name as returned by get_protocol()
    @port: external port as returned by get_external_port()

    Returns the dot-separated QNAME as it appears in the packet ('' for the
    root name), or False if the packet does not qualify or is truncated or
    malformed.
    """
    if packet_dir != PKT_DIR_OUTGOING or protocol != 'udp' or port != 53:
        return False
    # QDCOUNT follows the 8-byte UDP header and the 4 bytes of DNS ID/flags.
    i = get_ip_header_len(packet) + 12
    try:
        qdcount = struct.unpack_from('!H', packet, i)[0]
    except struct.error:
        return False
    if qdcount != 1:
        return False
    # Skip the remaining 8 bytes of the 12-byte DNS header to reach QNAME.
    i += 8
    labels = []
    while True:
        try:
            length = struct.unpack_from('!B', packet, i)[0]
        except struct.error:
            return False
        if length == 0:
            break
        # Labels are at most 63 octets; larger length bytes are compression
        # pointers or garbage, neither of which belongs in a query's question.
        if length > 63 or i + 1 + length > len(packet):
            return False
        label = packet[i + 1:i + 1 + length]
        if not isinstance(label, str):
            # Python 3 bytes/bytearray -> str, mapping each byte 1:1.
            label = label.decode('latin-1')
        labels.append(label)
        i += length + 1
    qname = '.'.join(labels)
    # Step over the terminating zero-length label to QTYPE, then QCLASS.
    i += 1
    try:
        qtype, qclass = struct.unpack_from('!HH', packet, i)
    except struct.error:
        return False
    if qtype not in (1, 28) or qclass != 1:
        return False
    return qname
def is_subdomain(child, parent):
    """Return True if domain name `child` is covered by rule domain `parent`.

    `parent` may be:
      * '*'          -- matches every name, including '';
      * '*<suffix>'  -- matches any child ending with <suffix>, e.g.
                        '*.google.com' matches 'mail.google.com' but not
                        'google.com';
      * a plain name -- matches that name and its subdomains, e.g.
                        'google.com' matches 'google.com' and
                        'www.google.com' but not 'evilgoogle.com';
      * ''           -- matches only ''.
    """
    if parent.startswith('*'):
        return child.endswith(parent[1:])
    if child == parent:
        return True
    # Require a label boundary so 'google.com' does not match 'evilgoogle.com'.
    # NOTE(review): if plain rule names are meant to match exactly, drop this.
    return parent != '' and child.endswith('.' + parent)
def get_common_suffix_len(str1, str2):
    """Return how many trailing characters str1 and str2 have in common."""
    shared = 0
    # Walk both strings backwards in lockstep; zip stops at the shorter one.
    for left, right in zip(reversed(str1), reversed(str2)):
        if left != right:
            break
        shared += 1
    return shared
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement