Advertisement
opexxx

dnsredir.py

May 18th, 2014
316
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 14.57 KB | None | 0 0
  1. #!/usr/bin/env python
  2. """
  3. A small DNS server that answers a small set of queries
  4. and proxies the rest through a 'real' DNS server.
  5.  
  6. See the documentation for more details.
  7.  
  8. NOTES:
  9.  - no attempt is made to make IDs unguessable.  This is a security
  10.    weakness that can be exploited in a hostile environment.
  11.  - will complain about slack data
  12.  
  13. TODO:
  14.  - more record types if needed: PTR, TXT, ...
  15. """
  16.  
  17. import optparse, re, socket, struct, time
  18.  
  19. publicDNS = '8.8.8.8' # google's public DNS server
  20. publicDNS6 = '::ffff:' + publicDNS
  21. gQuiet = False
  22.  
  23. QUERY,IQUERY = 0,1
  24. IN = 1
  25. A,NS,CNAME,PTR,MX,TXT,AAAA = 1,2,5,12,15,16,28
  26. LABLEN,LABOFF = 0,3
  27.  
  28. class Error(Exception) :
  29.     pass
  30.  
  31. def log(fmt, *args) :
  32.     if not gQuiet :
  33.         now = time.time()
  34.         ts = time.strftime('%Y-%m-%d:%H:%M:%S', time.localtime(now))
  35.         print ts, fmt % args
  36.  
  37. def getBits(num, *szs) :
  38.     """Get bits from num from right to left."""
  39.     rs = []
  40.     for sz in szs :
  41.         mask = (1 << sz) - 1
  42.         rs.append(num & mask)
  43.         num >>= sz
  44.     return rs
  45. def putBits(szs, *args) :
  46.     """Put bits into an integer from right to left."""
  47.     num = 0
  48.     sh = 0
  49.     for (sz,a) in zip(szs, args) :
  50.         mask = (1 << sz) - 1
  51.         num |= ((a & mask) << sh)
  52.         sh += sz
  53.     return num
  54.  
  55. def getPacked(fmt, buf, off) :
  56.     sz = struct.calcsize(fmt)
  57.     return struct.unpack(fmt, buf[off : off + sz]), off+sz
  58. def putPacked(buf, fmt, *args) :
  59.     buf.append(struct.pack(fmt, *args))
  60.  
  61. def getLabel(buf, off, ctx) :
  62.     """Get a DNS label, without any decompression."""
  63.     (b,),dummy = getPacked('!B', buf, off)
  64.     typ = b >> 6
  65.     if typ == LABLEN :
  66.         endOff = off + 1 + b
  67.         name = buf[off + 1 : endOff]
  68.         val = name
  69.     elif typ == LABOFF :
  70.         (ptr,),endOff = getPacked('!H', buf, off)
  71.         ptr &= 0x3fff
  72.         val = ptr
  73.     else :
  74.         raise Error("invalid label type %d at %d" % (typ, off))
  75.     return (typ, val, off), endOff
  76.  
  77. def getDomName(buf, off, ctx) :
  78.     """Get a domain name, performing decompression."""
  79.     idx = off
  80.     labs = []
  81.     endOff = off
  82.     while True :
  83.         (t,v,o),endOff = getLabel(buf, endOff, ctx)
  84.         labs.append((t,v,o))
  85.         if t == LABOFF or v == '' : # terminate at pointer or empty name
  86.             break
  87.  
  88.     if idx not in ctx : # decompress
  89.         ctx[idx] = None # Avoid loops during recursion. see below.
  90.         r = []
  91.         for t,v,o in labs :
  92.             if t == LABLEN :
  93.                 r.append(v)
  94.             else :
  95.                 name,dummy = getDomName(buf, v, ctx)
  96.                 r.append(name)
  97.         ctx[idx] = '.'.join(r)
  98.     if ctx[idx] is None :
  99.         raise Error("invalid loop in domain decompression at %d" % o)
  100.     return ctx[idx],endOff
  101.  
  102. def putDomain(buf, dom) :
  103.     """Put a domain name. Never compressed..."""
  104.     labs = dom.rstrip('.').split('.')
  105.     if len(dom) > 255 or any(len(l) > 63 or len(l) == 0 for l in labs) :
  106.         raise Error("Cannot encode domain: %s" % dom)
  107.     labs.append('') # terminator
  108.     for l in labs :
  109.         putPacked(buf, "!B", len(l))
  110.         buf.append(l)
  111.  
  112. class DNSQuestion(object) :
  113.     def get(self, buf, off, ctx) :
  114.         self.name,off = getDomName(buf, off, ctx)
  115.         (self.type,self.klass),off = getPacked("!HH", buf, off)
  116.         return off
  117.     def put(self, buf) :
  118.         putDomain(buf, self.name)
  119.         putPacked(buf, "!HH", self.type, self.klass)
  120.     def __str__(self) :
  121.         return '[Q %s %s %s]' % (self.name, self.type, self.klass)
  122.  
  123. class DNSResA(object) :
  124.     def __init__(self, val=None) :
  125.         if val is not None :
  126.             self.val = val
  127.     def get(self, buf, off) :
  128.         self.val = mkIPv4(buf[off : off+4])
  129.         return off+4
  130.     def put(self, buf) :
  131.         buf.append(parseIPv4(self.val))
  132.     def __str__(self) :
  133.         return '[A %s]' % (self.val)
  134.  
  135. class DNSResAAAA(object) :
  136.     def __init__(self, val=None) :
  137.         if val is not None :
  138.             self.val = val
  139.     def get(self, buf, off) :
  140.         self.val = mkIPv6(buf[off : off+16])
  141.         return off+16
  142.     def put(self, buf) :
  143.         buf.append(parseIPv6(self.val))
  144.     def __str__(self) :
  145.         return '[AAAA %s]' % (self.val)
  146.  
  147. class DNSResRec(object) :
  148.     children = {
  149.         A:      DNSResA,
  150.         AAAA:   DNSResAAAA,
  151.         #CNAME:  DNSResCName,
  152.         #MX:     DNSResMx,
  153.         #NS:     DNSResNs,
  154.         #PTR:    DNSResPtr,
  155.         #TXT:    DNSResTxt,
  156.     }
  157.     def get(self, buf, off, ctx) :
  158.         self.name,off = getDomName(buf, off, ctx)
  159.         (self.type,self.klass,self.ttl, l),off = getPacked("!HHIH", buf, off)
  160.         self.nested = buf[off : off + l]
  161.         off += l
  162.  
  163.         self.val = None
  164.         if self.type in self.children :
  165.             self.val = self.children[self.type]()
  166.             n = self.val.get(self.nested, 0)
  167.             if n != len(self.nested) :
  168.                 raise Error("unexpected nested slack data: %r" % self.nested[n:])
  169.         return off
  170.  
  171.     def put(self, buf) :
  172.         if self.val is not None :
  173.             buf2 = []
  174.             self.val.put(buf2)
  175.             self.nested = ''.join(buf2)
  176.  
  177.         putDomain(buf, self.name)
  178.         l = len(self.nested)
  179.         putPacked(buf, "!HHIH", self.type, self.klass, self.ttl, l)
  180.         buf.append(self.nested)
  181.  
  182.     def __str__(self) :
  183.         v = repr(self.nested)
  184.         if self.val :
  185.             v = self.val
  186.         return '[RR %s %s %s %s %s]' % (self.type, self.klass, self.ttl, self.name, v)
  187.  
  188. def getArray(buf, off, cnt, constr, ctx) :
  189.     objs = []
  190.     for n in xrange(cnt) :
  191.         obj = constr()
  192.         objs.append(obj)
  193.         off = obj.get(buf, off, ctx)
  194.     return objs, off
  195. def putArray(buf, arr) :
  196.     for obj in arr :
  197.         obj.put(buf)
  198. def arrStr(xs) :
  199.     return '[%s]' % (', '.join(str(x) for x in xs))
  200.  
  201. class DNSMsg(object) :
  202.     def __init__(self, buf=None) :
  203.         self.id = 0
  204.         self.rcode, self.z, self.ra, self.rd, self.tc, self.aa, self.opcode, self.qr = 0, 0, 0, 0, 0, 0, 0, 0
  205.         self.qd, self.an, self.ns, self.ar = [],[],[],[]
  206.  
  207.         if buf is not None :
  208.             self.get(buf)
  209.  
  210.     def get(self, buf) :
  211.         ctx = {}
  212.         (self.id, bits, qdcount, ancount, nscount, arcount),n = getPacked("!HHHHHH", buf, 0)
  213.         self.rcode, self.z, self.ra, self.rd, self.tc, self.aa, self.opcode, self.qr = getBits(bits, 4, 3, 1, 1, 1, 1, 4, 1)
  214.         self.qd,n = getArray(buf, n, qdcount, DNSQuestion, ctx)
  215.         self.an,n = getArray(buf, n, ancount, DNSResRec, ctx)
  216.         self.ns,n = getArray(buf, n, nscount, DNSResRec, ctx)
  217.         self.ar,n = getArray(buf, n, arcount, DNSResRec, ctx)
  218.         if n < len(buf) :
  219.             raise Error("unexpected slack data: %r" % buf[n:])
  220.  
  221.     def put(self) :
  222.         buf = []
  223.         bits = putBits((4, 3, 1, 1, 1, 1, 4, 1), self.rcode, self.z, self.ra, self.rd, self.tc, self.aa, self.opcode, self.qr)
  224.         putPacked(buf, "!HHHHHH", self.id, bits, len(self.qd), len(self.an), len(self.ns), len(self.ar))
  225.         putArray(buf, self.qd)
  226.         putArray(buf, self.an)
  227.         putArray(buf, self.ns)
  228.         putArray(buf, self.ar)
  229.         bytes = ''.join(buf)
  230.         if len(bytes) > 64*1024 :
  231.             raise Error("Response is too big: %d!" % len(bytes))
  232.         return bytes
  233.  
  234.     def __str__(self) :
  235.         arrs = 'qd=%s an=%s ns=%s ar=%s' % tuple(arrStr(x) for x in (self.qd,self.an,self.ns,self.ar))
  236.         return '[DNSMsg id=%d rcode=%d z=%d ra=%d rd=%d tc=%d aa=%d opcode=%d qr=%d %s]' % (self.id, self.rcode, self.z, self.ra, self.rd, self.tc, self.aa, self.opcode, self.qr, arrs)
  237.  
  238. def lookup(db, ty, name) :
  239.     for ty_,pat,val in db :
  240.         if ty == ty_ and re.match(pat, name) :
  241.             return val
  242.  
  243. def mkResp(q, val, ttl, id, opcode) :
  244.     a = DNSResRec()
  245.     a.name, a.type, a.klass = q.name, q.type, q.klass
  246.     a.ttl, a.val = ttl, val
  247.    
  248.     resp = DNSMsg()
  249.     resp.id = id
  250.     resp.qr = 1
  251.     resp.opcode = opcode
  252.     resp.qd = [q]
  253.     resp.an = [a]
  254.     return resp
  255.  
  256. def procQuery(opt, s, m, peer) :
  257.     resp = None
  258.     if m.opcode == QUERY and len(m.qd) == 1 :
  259.         q = m.qd[0]
  260.         log("Simple query from %s class=%d type=%d name=%r", peer, q.klass, q.type, q.name)
  261.         if q.klass == IN and q.type == A :
  262.             ip = lookup(opt.names, 'A', q.name)
  263.             if ip :
  264.                 log("Answering %s/%d query IN A %s with %s", peer, m.id, q.name, ip)
  265.                 resp = mkResp(q, DNSResA(ip), opt.ttl, m.id, m.opcode)
  266.         if q.klass == IN and q.type == AAAA :
  267.             ip = lookup(opt.names, 'AAAA', q.name)
  268.             if ip :
  269.                 log("Answering %s/%d query IN AAAA %s with %s", peer, m.id, q.name, ip)
  270.                 resp = mkResp(q, DNSResAAAA(ip), opt.ttl, m.id, m.opcode)
  271.     return resp
  272.  
  273. class Proxy(object) :
  274.     """Proxy objects and the global proxy table."""
  275.     timeo = 30
  276.     id = 1
  277.     tab = {}
  278.  
  279.     @staticmethod
  280.     def clean() :
  281.         self = Proxy
  282.         now = time.time()
  283.         for k,v in self.tab.items() :
  284.             if v.expire <= now :
  285.                 log("expire proxy request %d", k)
  286.                 del self.tab[k]
  287.  
  288.     def __init__(self, peer, msg) :
  289.         """Make a proxy object and put it in the prox table."""
  290.         self.expire = time.time() + self.timeo
  291.         self.peer = peer
  292.         self.origId = msg.id
  293.         self.id = Proxy.id
  294.  
  295.         if self.id in self.tab :
  296.             # should only happen in hostile situations or under heavy loads
  297.             raise Error("proxy ID collision!")
  298.         self.tab[self.id] = self
  299.  
  300.         Proxy.id = (Proxy.id + 1) & 0xffff # weak ID generation
  301.  
  302.  
  303. def sendMsg(s, addr, msg) :
  304.     buf = msg.put()
  305.     if s.sendto(buf, addr) != len(buf) :
  306.         raise Error("failure sending msg: " + e)
  307.  
  308. def procMsg(opt, sock, buf, peer) :
  309.     Proxy.clean()
  310.  
  311.     m = DNSMsg()
  312.     try :
  313.         m.get(buf)
  314.     except Error, e :
  315.         log("Error parsing msg from %s: %s", peer, e)
  316.         return
  317.  
  318.     if m.qr == 0 : # query from client - answer it or proxy it
  319.         resp = procQuery(opt, sock, m, peer)
  320.         if resp is not None :
  321.             log("Send answer to %s", peer)
  322.             sendMsg(sock, peer, resp)
  323.         else : # not processed, proxy it
  324.             p = Proxy(peer, m)
  325.             log("Proxy msg from client %s/%d to server %s/%d", peer, m.id, opt.srv, p.id)
  326.             m.id = p.id
  327.             sendMsg(sock, opt.srv, m)
  328.     else : # response from server - proxy back to client
  329.         p = Proxy.tab.get(m.id)
  330.         if p is not None :
  331.             del Proxy.tab[m.id]
  332.             log("Proxy msg from server %s/%d to client %s/%d", peer, m.id, p.peer, p.origId)
  333.             m.id = p.origId
  334.             sendMsg(sock, p.peer, m)
  335.         else :
  336.             log("Unexpected response from %s/%d", peer, m.id)
  337.  
  338. def server(opt) :
  339.     if opt.six :
  340.         s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
  341.     else :
  342.         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  343.         try :
  344.             # Handle both ipv4 and ipv6.
  345.             # This is on by default on many but not all systems.
  346.             s.setsockopt(socket.IPPROTO_IPV6, IPV6_V6ONLY, 0)
  347.         except Exception, e :
  348.             pass
  349.     s.bind((opt.bindAddr, opt.port))
  350.     while True :
  351.         buf,peer = s.recvfrom(64 * 1024)
  352.         log("Received %d bytes from %s", len(buf), peer)
  353.         try :
  354.             procMsg(opt, s, buf, peer)
  355.         except Error,e :
  356.             log("Error processing from %s", peer)
  357.  
  358. def mkIPv4(xs) :
  359.     return socket.inet_ntoa(xs)
  360. def parseIPv4(s) :
  361.     try :
  362.         return socket.inet_aton(s)
  363.     except :
  364.         raise Error("Bad IP address format: %r" % s)
  365.  
  366. def mkHex16(buf) :
  367.     return '%x' % ((ord(buf[0]) << 8) | ord(buf[1]))
  368. def parseHex16(n) :
  369.     return chr((n >> 8) & 0xff) + chr(n & 0xff)
  370. def mkIPv6(bs) :
  371.     assert len(bs) == 16
  372.     ns = [mkHex16(bs[n:n+2]) for n in xrange(0,16,2)]
  373.     return ':'.join(ns)
  374. def parseIPv6(s) :
  375.     try :
  376.         ws = s.split(':')
  377.         if '.' in ws[-1] : # 32-bit IPv4 instead 16-bit hex
  378.             bs = parseIPv4(ws[-1])
  379.             ws[-1:] = [mkHex16(bs), mkHex16(bs[2:])]
  380.        
  381.         if '' in ws :
  382.             idx = ws.index('')
  383.             n = 8 - (len(ws) - 1)
  384.             if n > 0 : # expand at idx to full width
  385.                 ws[idx : idx+1] = ['0'] * n
  386.  
  387.             while '' in ws : # all others become zeros without expansion
  388.                 ws[ws.index('')] = '0'
  389.  
  390.         if len(ws) != 8 :
  391.             raise Error("wrong length") # jump to err
  392.         try :
  393.             ns = [int(w, 16) for w in ws]
  394.         except ValueError, e :
  395.             raise Error("bad hex")
  396.  
  397.         if any(n < 0 or n > 0xffff for n in ns) :
  398.             raise Error("bad value")
  399.  
  400.         return ''.join(parseHex16(n) for n in ns)
  401.     except Error, e :
  402.         raise Error("Invalid IPv6 address: %r" % s)
  403.  
  404. def parseNames(args) :
  405.     tab = []
  406.     for a in args :
  407.         if ':' not in a :
  408.             raise Error("Argument must be type:name=value -- %r" % a)
  409.         ty,rest = a.split(':', 1)
  410.         if '=' not in rest :
  411.             raise Error("Argument must be type:name=value -- %r" % a)
  412.         nm,val = rest.split('=', 1)
  413.  
  414.         pat = '^' + nm + '$' # anchor regexp
  415.         if ty == 'A' :
  416.             dummy = parseIPv4(val)
  417.         elif ty == 'AAAA' :
  418.             dummy = parseIPv6(val)
  419.         else :
  420.             raise Error("Unsupported query type %r in %r" % (ty, a))
  421.         tab.append((ty,pat,val))
  422.     return tab
  423.  
  424. def getopts() :
  425.     p = optparse.OptionParser(usage="usage: %prog [opts] [type:name:val ...]")
  426.     p.add_option('-d', dest='dnsServer', default=None, help='default DNS server. Default=' + publicDNS)
  427.     p.add_option('-b', dest='bindAddr', default='', help='Address to bind to. Default=any')
  428.     p.add_option('-p', dest='port', type=int, default=53, help='Port to listen on. Default=53')
  429.     p.add_option('-P', dest='dnsServerPort', type=int, default=53, help='Port of default DNS server. Default=53')
  430.     p.add_option('-t', dest='ttl', type=int, default=30, help='TTL for responses. Default=30 seconds')
  431.     p.add_option('-q', dest='quiet', action='store_true', help='Quiet')
  432.     p.add_option('-6', dest='six', action='store_true', help='Use IPv6 server socket')
  433.     opt,args = p.parse_args()
  434.     opt.names = parseNames(args)
  435.     if opt.dnsServer == None :
  436.         opt.dnsServer = publicDNS6 if opt.six else publicDNS
  437.     opt.srv = opt.dnsServer, opt.dnsServerPort
  438.     global gQuiet
  439.     gQuiet = opt.quiet
  440.     return opt
  441.  
  442. def main() :
  443.     opt = getopts()
  444.     server(opt)
  445.  
  446. if __name__ == '__main__' :
  447.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement