Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python
- from main import PKT_DIR_INCOMING, PKT_DIR_OUTGOING
- import socket
- import struct
class Firewall:
    """Rule-based packet filter between an internal and an external interface.

    Rules are loaded from the file named by ``config['rule']``; the GeoIP
    database is loaded from ``geoipdb.txt`` in the current working directory.
    Packets are checked against the rules in file order: the first matching
    DNS rule or protocol rule decides the verdict ('pass' forwards, 'deny'
    answers with a spoofed reply where one is built, anything else drops).
    'http' rules only trigger logging and never decide a verdict.  Packets
    that match no deciding rule are passed.
    """

    def __init__(self, config, iface_int, iface_ext):
        self.iface_int = iface_int
        self.iface_ext = iface_ext
        # rules[i] = [<verdict>, <protocol>, <external IP address>, <external port>]
        #         or [<verdict>, 'dns' | 'http', <domain name>]
        self.rules = self._load_rules(config['rule'])
        # geoipdb[<2-character country code>] = [(<start IP address>, <end IP address>), ...]
        self.geoipdb = self._load_geoipdb('geoipdb.txt')

    def _load_rules(self, rules_filename):
        """Parse the rules file into lists of lower-cased fields.

        Blank or whitespace-only lines and lines whose first field starts with
        '%' are skipped.

        Raises:
            ValueError: a rule has too few fields.  Failing here is clearer
                than an IndexError on the first packet that reaches the rule.
        """
        rules = []
        with open(rules_filename) as rules_file:
            for rule_line in rules_file:
                fields = rule_line.lower().split()
                if not fields or fields[0].startswith('%'):
                    continue
                # dns/http rules carry a host name; all others need an IP and a port.
                if len(fields) > 1 and fields[1] in ('dns', 'http'):
                    min_fields = 3
                else:
                    min_fields = 4
                if len(fields) < min_fields:
                    raise ValueError('malformed rule in %s: %r'
                                     % (rules_filename, rule_line.strip()))
                rules.append(fields)
        return rules

    def _load_geoipdb(self, geoipdb_filename):
        """Load '<start IP> <end IP> <country code>' lines, grouped by upper-cased code.

        Lines with fewer than three fields (e.g. blank lines) are skipped.
        """
        geoipdb = {}
        with open(geoipdb_filename) as geoipdb_file:
            for line in geoipdb_file:
                fields = line.split()
                if len(fields) < 3:
                    continue
                geoipdb.setdefault(fields[2].upper(), []).append((fields[0], fields[1]))
        return geoipdb

    def handle_packet(self, packet_dir, packet):
        """Apply the rules to one packet and forward, reflect or drop it.

        @packet_dir: either PKT_DIR_INCOMING or PKT_DIR_OUTGOING
        @packet: the actual data of the IPv4 packet (including IP header)
        """
        protocol = get_protocol(packet)
        if not protocol:
            self.pass_packet(packet_dir, packet)
            return
        external_ip = get_external_ip(packet_dir, packet)
        external_port = get_external_port(packet_dir, packet, protocol)
        # The DNS question does not depend on the rule, so parse it at most once.
        qname = None
        qname_parsed = False
        for rule in self.rules:
            if rule[1] == 'dns':
                if not qname_parsed:
                    qname = matches_dns_rules(packet_dir, packet, protocol, external_port)
                    qname_parsed = True
                if qname and is_subdomain(qname, rule[2]):
                    if rule[0] == 'pass':
                        self.pass_packet(packet_dir, packet)
                    elif rule[0] == 'deny':
                        response = get_dns_response(packet)
                        # No response is built for AAAA queries; they are just dropped.
                        if response:
                            self.reflect_packet(packet_dir, response)
                    return
            elif rule[1] == 'http':
                try_to_log(rule[2], protocol, external_ip, external_port, packet_dir, packet)
            elif self.matches(rule, protocol, external_ip, external_port):
                if rule[0] == 'pass':
                    self.pass_packet(packet_dir, packet)
                elif rule[0] == 'deny' and rule[1] == 'tcp':
                    self.reflect_packet(packet_dir, get_rst_packet(packet))
                return
        # If none of the rules match, just pass the packet
        self.pass_packet(packet_dir, packet)

    def matches(self, rule, protocol, external_ip, external_port):
        """Return True if a [verdict, protocol, ip, port] rule matches the packet fields."""
        return (protocol == rule[1]
                and self.matches_external_ip(rule[2], external_ip)
                and self.matches_external_port(rule[3], external_port))

    def _ip_to_int(self, ip):
        """Convert a dotted-quad IPv4 string to an unsigned 32-bit integer."""
        return struct.unpack('!I', socket.inet_aton(ip))[0]

    def matches_external_ip(self, rule_ext_ip, ext_ip):
        """Return True if ``ext_ip`` matches ``rule_ext_ip``.

        ``rule_ext_ip`` may be 'any', an exact IP, a 2-letter country code
        (looked up in the GeoIP DB) or a CIDR subnet 'a.b.c.d/n'.
        """
        if rule_ext_ip in ('any', ext_ip):
            return True
        if len(rule_ext_ip) == 2:  # 2-letter country code
            ip_ranges = self.geoipdb.get(rule_ext_ip.upper())
            if ip_ranges:
                ip = self._ip_to_int(ext_ip)
                for start, finish in ip_ranges:
                    if self._ip_to_int(start) <= ip <= self._ip_to_int(finish):
                        return True
        elif '/' in rule_ext_ip:
            return is_in_subnet(ext_ip, rule_ext_ip)
        return False

    def matches_external_port(self, rule_ext_port, ext_port):
        """Return True if ``ext_port`` matches 'any', an exact port or an inclusive 'lo-hi' range."""
        if rule_ext_port in ('any', str(ext_port)):
            return True
        if '-' in rule_ext_port:
            bounds = [int(part) for part in rule_ext_port.split('-')]
            return bounds[0] <= ext_port <= bounds[1]
        return False

    def pass_packet(self, packet_dir, packet):
        """Forward ``packet`` onward in its direction of travel."""
        if packet_dir == PKT_DIR_INCOMING:
            self.iface_int.send_ip_packet(packet)
        elif packet_dir == PKT_DIR_OUTGOING:
            self.iface_ext.send_ip_packet(packet)

    def reflect_packet(self, packet_dir, packet):
        """Send ``packet`` back out of the interface the original packet arrived on."""
        if packet_dir == PKT_DIR_INCOMING:
            self.iface_ext.send_ip_packet(packet)
        elif packet_dir == PKT_DIR_OUTGOING:
            self.iface_int.send_ip_packet(packet)
def get_protocol(packet):
    """Return 'icmp', 'tcp' or 'udp' for the packet's IPv4 protocol field, else None.

    ``packet`` may be any byte buffer (Python 2 str, bytes, bytearray,
    memoryview); the protocol number is the byte at offset 9 of the IP header.
    """
    protocol_switch = {1: 'icmp', 6: 'tcp', 17: 'udp'}
    protocol = struct.unpack_from('!B', packet, 9)[0]
    return protocol_switch.get(protocol)
def get_external_ip(packet_dir, packet):
    """Return the dotted-quad address of the external (non-local) host.

    For incoming packets that is the IPv4 source address (bytes 12-15), for
    outgoing packets the destination address (bytes 16-19).

    Raises:
        ValueError: ``packet_dir`` is neither PKT_DIR_INCOMING nor
            PKT_DIR_OUTGOING (previously an opaque UnboundLocalError).
    """
    if packet_dir == PKT_DIR_INCOMING:
        external_ip = packet[12:16]  # source
    elif packet_dir == PKT_DIR_OUTGOING:
        external_ip = packet[16:20]  # destination
    else:
        raise ValueError('unknown packet direction: %r' % (packet_dir,))
    # Convert the packed 4-byte address to dotted-quad notation.
    return socket.inet_ntoa(external_ip)
def get_external_port(packet_dir, packet, protocol):
    """Return the external host's port for TCP/UDP, or the ICMP type for ICMP.

    For TCP/UDP the transport header holds the source port at offset 0 and the
    destination port at offset 2; the external host is the source of incoming
    packets and the destination of outgoing ones.

    Raises:
        ValueError: ``protocol`` is not 'icmp', 'tcp' or 'udp'
            (previously an opaque UnboundLocalError).
    """
    offset = get_ip_header_len(packet)
    if protocol == 'icmp':
        # The ICMP type is the first byte of the ICMP header.
        return struct.unpack_from('!B', packet, offset)[0]
    if protocol in ('tcp', 'udp'):
        if packet_dir == PKT_DIR_OUTGOING:
            offset += 2  # destination port
        return struct.unpack_from('!H', packet, offset)[0]
    raise ValueError('unsupported protocol: %r' % (protocol,))
def get_ip_header_len(packet):
    """Return the IPv4 header length in bytes.

    The IHL field is the low nibble of the first byte and counts 32-bit
    words.  Works on any byte buffer (Python 2 str, bytes, bytearray).
    """
    return (struct.unpack_from('!B', packet, 0)[0] & 0x0F) * 4
def is_in_subnet(ip, subnet):
    """Return True if dotted-quad ``ip`` lies inside CIDR ``subnet`` ('a.b.c.d/n').

    Host bits set in the subnet's address are ignored.  The mask is computed
    arithmetically, so '/0' and '/32' both work (the previous string-based
    mask crashed on '/32' with int('', 2)).

    Raises:
        ValueError: the prefix length is not in 0..32.
    """
    parts = subnet.split('/')
    network, prefix_len = parts[0], int(parts[1])
    if not 0 <= prefix_len <= 32:
        raise ValueError('invalid prefix length in subnet %r' % (subnet,))
    mask = (0xFFFFFFFF << (32 - prefix_len)) & 0xFFFFFFFF
    start = struct.unpack('!I', socket.inet_aton(network))[0] & mask
    finish = start | (~mask & 0xFFFFFFFF)
    address = struct.unpack('!I', socket.inet_aton(ip))[0]
    return start <= address <= finish
def matches_dns_rules(packet_dir, packet, protocol, port):
    """Return the lower-cased QNAME of an outgoing DNS query that DNS rules apply to.

    A packet qualifies when it is outgoing UDP to port 53, carries exactly one
    question, and that question is of type A (1) or AAAA (28) and class IN (1).
    The name is lower-cased because rules are lower-cased when loaded and
    DNS names compare case-insensitively.  Returns False for any
    non-qualifying or truncated/malformed packet.

    ``packet`` is expected to be a byte string (Python 2 str) here, since the
    QNAME is assembled with str operations.
    """
    if packet_dir != PKT_DIR_OUTGOING or protocol != 'udp' or port != 53:
        return False
    # The DNS message starts right after the 8-byte UDP header.
    dns_start = get_ip_header_len(packet) + 8
    try:
        # QDCOUNT sits at offset 4 of the DNS header.
        qdcount = struct.unpack_from('!H', packet, dns_start + 4)[0]
    except struct.error:
        return False
    if qdcount != 1:
        return False
    # The question section follows the fixed 12-byte DNS header.
    i = dns_start + 12
    labels = []
    while True:
        try:
            length = struct.unpack_from('!B', packet, i)[0]
        except struct.error:
            return False
        if length == 0:
            break
        # Labels are at most 63 octets; larger values are compression
        # pointers, which a plain query should not contain.
        if length > 63:
            return False
        label = packet[i + 1:i + 1 + length]
        if len(label) != length:
            return False  # truncated label
        labels.append(label)
        i += length + 1
    qname = '.'.join(labels).lower()
    # Skip the terminating zero-length label to reach QTYPE, then QCLASS.
    i += 1
    try:
        qtype, qclass = struct.unpack_from('!HH', packet, i)
    except struct.error:
        return False
    if qtype not in (1, 28) or qclass != 1:
        return False
    return qname
def is_subdomain(child, parent):
    """Return True if domain ``child`` is matched by the rule pattern ``parent``.

    - '*' matches every name.
    - A pattern starting with '*' (e.g. '*.google.com') matches any name
      ending with the rest of the pattern ('mail.google.com', but not
      'google.com').
    - Any other pattern matches the name itself and its subdomains, split on
      label boundaries: 'google.com' matches 'www.google.com' but not
      'agoogle.com'.
    - An empty pattern matches only the empty name.

    The comparison is case-insensitive, as DNS names are.
    """
    child = child.lower()
    parent = parent.lower()
    if parent == '':  # Edge case
        return child == ''
    if parent.startswith('*'):
        return child.endswith(parent[1:])
    return child == parent or child.endswith('.' + parent)
def get_common_suffix_len(str1, str2):
    """Return how many trailing characters ``str1`` and ``str2`` have in common."""
    shared = 0
    # Walk both sequences backwards in lockstep; zip stops at the shorter one.
    for left, right in zip(reversed(str1), reversed(str2)):
        if left != right:
            break
        shared += 1
    return shared
def get_dns_response(packet):
    """Build a spoofed IPv4/UDP DNS answer for a denied query.

    The answer echoes the query's ID and question, sets the QR (response)
    bit, and carries a single A record with TTL 1 pointing at
    169.229.49.130.  Returns None when the question's QTYPE is AAAA (28).
    Assumes a well-formed single-question query (presumably already checked
    by matches_dns_rules — confirm at call sites).
    """
    header_len = get_ip_header_len(packet)
    old_ip_header = packet[:header_len]
    old_udp_header = packet[header_len:header_len + 8]
    query = packet[header_len + 8:]

    # Walk the QNAME labels; afterwards `end` points just past the zero byte.
    end = 12
    while query[end] != '\x00':
        end += struct.unpack('!B', query[end])[0] + 1
    end += 1

    if struct.unpack('!H', query[end:end + 2])[0] == 28:
        return None

    # Set the QR bit in the high flags byte, keep everything else.
    flags_hi = struct.unpack('!B', query[2])[0] | 0x80
    question = query[12:end + 4]  # QNAME + QTYPE + QCLASS
    answer = query[12:end] + \
        '\x00\x01\x00\x01\x00\x00\x00\x01\x00\x04\xa9\xe5\x31\x82'
    message = query[0:2] + struct.pack('!B', flags_hi) + query[3] + \
        '\x00\x01\x00\x01\x00\x00\x00\x00' + question + answer

    segment = construct_udp_header(old_ip_header, old_udp_header, message) + message
    return construct_ip_header(old_ip_header, len(segment)) + segment
def construct_udp_header(old_ip_header, old_udp_header, message):
    """Build the UDP header for a reply to the datagram described by the old headers.

    Source and destination ports are swapped, the length covers the 8-byte
    header plus ``message``, and the checksum is computed over the
    (address-swapped) pseudo header, the new header and ``message``.
    """
    src_port = old_udp_header[2:4]
    dst_port = old_udp_header[0:2]
    length = 8 + len(message)
    udp_header = src_port + dst_port + struct.pack('!H', length) + b'\x00\x00'
    pseudo_header = get_pseudo_header(old_ip_header, length)
    checksum = get_checksum(pseudo_header + udp_header + message)
    # RFC 768: a computed checksum of zero is transmitted as all ones, because
    # an all-zero field means "no checksum".
    if checksum == b'\x00\x00':
        checksum = b'\xff\xff'
    return udp_header[:6] + checksum
def get_pseudo_header(old_ip_header, segment_len):
    """Return the TCP/UDP checksum pseudo header for a reply to ``old_ip_header``.

    Layout: source IP, destination IP (swapped relative to the old header,
    since the reply travels back), a zero byte, the protocol byte, and the
    16-bit segment length.  Slicing keeps this working for both Python 2 str
    and bytes headers.
    """
    src_ip = old_ip_header[16:20]
    dst_ip = old_ip_header[12:16]
    protocol = old_ip_header[9:10]
    return src_ip + dst_ip + b'\x00' + protocol + struct.pack('!H', segment_len)
def construct_ip_header(old_ip_header, segment_len, ttl=64):
    """Build a 20-byte IPv4 header for a reply to the packet with ``old_ip_header``.

    Source/destination addresses are swapped and the protocol is copied.
    The header uses version 4 / IHL 5, ID 0x1111, no fragmentation, and the
    given ``ttl``.  The TTL used to be hard-coded to 1, so replies sent out
    the external interface were discarded by the first router.

    Args:
        segment_len: length in bytes of the transport segment that follows.
        ttl: time-to-live for the new packet (0-255).
    """
    src_ip = old_ip_header[16:20]
    dst_ip = old_ip_header[12:16]
    protocol = old_ip_header[9:10]
    total_length = struct.pack('!H', 20 + segment_len)
    ip_header = (b'\x45\x00' + total_length + b'\x11\x11\x00\x00'
                 + struct.pack('!B', ttl) + protocol + b'\x00\x00' + src_ip + dst_ip)
    # The checksum field (bytes 10-11) is zero while the checksum is computed.
    checksum = get_checksum(ip_header)
    return ip_header[:10] + checksum + ip_header[12:]
def get_rst_packet(packet):
    """Return a complete IPv4 packet carrying a TCP RST that answers ``packet``."""
    header_len = get_ip_header_len(packet)
    ip_header, tcp_segment = packet[0:header_len], packet[header_len:]
    rst_header = get_rst_tcp_header(ip_header, tcp_segment)
    return construct_ip_header(ip_header, len(rst_header)) + rst_header
def get_rst_tcp_header(old_ip_header, old_tcp_header):
    """Return a 20-byte TCP RST+ACK header answering the offending segment.

    ``old_tcp_header`` is the offending TCP segment starting at its header;
    any payload may follow.  Ports are swapped, and the RST's sequence number
    is the peer's acknowledgement number.  The RST's acknowledgement number is
    the peer's sequence number plus the segment length: payload bytes plus one
    each for SYN and FIN (RFC 793).  Without the +1 for a SYN, a client in
    SYN-SENT ignores the reset.
    """
    src_port = old_tcp_header[2:4]
    dst_port = old_tcp_header[0:2]
    seq_num = old_tcp_header[8:12]
    old_seq = struct.unpack_from('!I', old_tcp_header, 4)[0]
    data_offset = (struct.unpack_from('!B', old_tcp_header, 12)[0] >> 4) * 4
    flags = struct.unpack_from('!B', old_tcp_header, 13)[0]
    # Payload length from the IP total length, so trailing padding is ignored.
    total_length = struct.unpack_from('!H', old_ip_header, 2)[0]
    payload_len = max(total_length - len(old_ip_header) - data_offset, 0)
    syn = 1 if flags & 0x02 else 0
    fin = 1 if flags & 0x01 else 0
    ack = struct.pack('!I', (old_seq + payload_len + syn + fin) & 0xFFFFFFFF)
    # Data offset 5 (20 bytes), flags RST|ACK, window 0, checksum 0, urgent 0.
    tcp_header = src_port + dst_port + seq_num + ack + \
        b'\x50\x14\x00\x00\x00\x00\x00\x00'
    pseudo_header = get_pseudo_header(old_ip_header, len(tcp_header))
    checksum = get_checksum(pseudo_header + tcp_header)
    return tcp_header[:16] + checksum + tcp_header[18:]
def try_to_log(rule_host_name, protocol, external_ip, external_port, packet_dir, packet):
    """Append an entry to ``http.log`` for an outgoing HTTP request matching a log rule.

    Only outgoing TCP packets to port 80 are considered.  The host is taken
    from the Host header and matched with is_subdomain.  Without a Host
    header, the external IP is used and must equal the rule exactly.

    Entry format: "<host> <method> <path> <version> <status> <size>\\r\\n".
    <method> is the first of GET/POST/HEAD found in the packet and <path> is
    the token after it ('/' if none is found).  <version> and <status> are
    fixed placeholders ('HTTP/1.1', '200').  <size> is the TCP payload length
    of this packet.
    """
    if packet_dir != PKT_DIR_OUTGOING or protocol != 'tcp' or external_port != 80:
        return
    start = packet.find('Host')
    if start == -1:
        host_name = external_ip
        matches = (rule_host_name == host_name)
    else:
        finish = packet.find('\r\n', start)
        if finish == -1:
            finish = len(packet)
        # "Host: name" -> "name", tolerating missing/extra spaces after the colon.
        host_name = packet[start:finish].partition(':')[2].strip()
        matches = is_subdomain(host_name, rule_host_name)
    if not matches:
        return
    start, method = packet.find('GET'), 'GET'
    if start == -1:
        for candidate in ('POST', 'HEAD'):
            start = packet.find(candidate)
            if start != -1:
                method = candidate
                break
    if start == -1:
        path = '/'
    else:
        start += len(method) + 1
        finish = packet.find(' ', start)
        path = packet[start:finish] if finish != -1 else packet[start:]
    # Payload size from the real IP and TCP header lengths (options included).
    ip_header_len = get_ip_header_len(packet)
    tcp_header_len = (struct.unpack_from('!B', packet, ip_header_len + 12)[0] >> 4) * 4
    object_size = str(max(len(packet) - ip_header_len - tcp_header_len, 0))
    data = [host_name, method, path, 'HTTP/1.1', '200', object_size]
    to_log = ' '.join(data) + '\r\n'
    # `with` closes the file even if the write fails.
    with open('http.log', 'a') as log_file:
        log_file.write(to_log)
def get_checksum(data):
    """Return the 16-bit Internet checksum (RFC 1071) of ``data``, packed big-endian.

    ``data`` is summed as big-endian 16-bit words; an odd trailing byte is
    padded with a zero low-order byte.  Carries are folded back in and the
    result is ones'-complemented.  Empty input yields b'\\xff\\xff'.

    Accepts any byte buffer (Python 2 str, bytes, bytearray, memoryview); a
    text string is treated as one octet per character (code points < 256).
    Runs in linear time (the old version re-sliced the input on every word).
    """
    try:
        octets = bytearray(data)
    except TypeError:
        octets = bytearray(data.encode('latin-1'))
    if len(octets) % 2:
        octets.append(0)
    total = 0
    for i in range(0, len(octets), 2):
        total += (octets[i] << 8) | octets[i + 1]
    while total >> 16:
        total = (total & 0xFFFF) + (total >> 16)
    return struct.pack('!H', ~total & 0xFFFF)
def add(a, b):
    """Add two base-2 digit strings and return the sum as a base-2 digit string (no '0b')."""
    total = sum(int(operand, 2) for operand in (a, b))
    return bin(total)[2:]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement