Advertisement
opexxx

dnsbulkmongo.py

Jun 6th, 2014
332
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 13.73 KB | None | 0 0
  1. #!/usr/bin/env python
  2. #
  3. # The demo will analyze the DNS records of the top 1 million web sites
  4. #
  5. # Before run the demo:
  6. #   1. Download and install mongodb and pymongo
  7. #
  8. #       http://www.mongodb.org/downloads
  9. #
  10. #       $apt-get install mongodb
  11. #
  12. #       http://api.mongodb.org/python/1.8.1%2B/installation.html
  13. #
  14. #       $easy_install -U pymongo
  15. #
  16. #   2. Download and unpack the top 1M sites from Alexa
  17. #
  18. #       $wget http://s3.amazonaws.com/alexa-static/top-1m.csv.zip
  19. #       $unzip top-1m.csv.zip
  20. #
  21. from __future__ import with_statement
  22.  
  23. import sys
  24. import os, os.path
  25. import logging
  26. import threading
  27. import zipfile
  28. import csv
  29. import time
  30. from datetime import datetime
  31.  
  32. try:
  33.     from cStringIO import StringIO
  34. except ImportError:
  35.     from StringIO import StringIO
  36.  
  37. import dns.name
  38. import dns.rdatatype
  39.  
  40. import pymongo
  41. import asyncdns
  42.  
  43. DEFAULT_MONGO_HOST = "localhost"
  44. DEFAULT_MONGO_PORT = 27017
  45. DEFAULT_DATABASE_NAME = "alexa"
  46. DEFAULT_DNS_SERVERS = asyncdns.Resolver.system_nameservers()
  47. DEFAULT_DNS_TIMEOUT = 30
  48. DEFAULT_CONCURRENCY = 20
  49.  
  50. def parse_cmdline():
  51.     from optparse import OptionParser
  52.  
  53.     parser = OptionParser(usage="usage: %prog [options] <files>")
  54.  
  55.     parser.add_option("--mongo-host", dest="mongo_host", default=DEFAULT_MONGO_HOST,
  56.                       metavar="HOST", help="mongodb host to connect to (default: %s)" % DEFAULT_MONGO_HOST)
  57.     parser.add_option("--mongo-port", dest="mongo_port", default=27017, type="int",
  58.                       metavar="PORT", help="mongodb port to connect to (default: %d)" % DEFAULT_MONGO_PORT)
  59.     parser.add_option("--db-name", dest="db_name", default=DEFAULT_DATABASE_NAME,
  60.                       metavar="NAME", help="mongodb database to open (default: %s)" % DEFAULT_DATABASE_NAME)
  61.  
  62.     parser.add_option("--dns-host", dest="dns_hosts", action="append", default=None,
  63.                       metavar="HOST", help="DNS server to query (default: %s)" % ', '.join(DEFAULT_DNS_SERVERS))
  64.     parser.add_option("-t", "--dns-timeout", dest="dns_timeout", default=DEFAULT_DNS_TIMEOUT, type="int",
  65.                       metavar="NUM", help="DNS query timeout in seconds (default: %d)" % DEFAULT_DNS_TIMEOUT)
  66.  
  67.     parser.add_option("--force-update", dest="force_update", default=False, action="store_true",
  68.                       help="force to update the exist domains")
  69.  
  70.     parser.add_option("-c", "--concurrency", default=DEFAULT_CONCURRENCY, type="int",
  71.                       metavar="NUM", help="Number of multiple queries to make (default: %d)" % DEFAULT_CONCURRENCY)
  72.  
  73.     parser.add_option("-v", "--verbose", action="store_const",
  74.                       const=logging.INFO, dest="log_level", default=logging.WARN)
  75.     parser.add_option("-d", "--debug", action="store_const",
  76.                       const=logging.DEBUG, dest="log_level")
  77.     parser.add_option("--log-format", dest="log_format",
  78.                       metavar="FMT", default="%(asctime)s %(levelname)s %(message)s")
  79.     parser.add_option("--log-file", dest="log_file", metavar="FILE")
  80.  
  81.     opts, args = parser.parse_args()
  82.  
  83.     return opts, args
  84.  
  85. class Updater(object):
  86.     logger = logging.getLogger("updater")
  87.  
  88.     def __init__(self, max_currency=20):
  89.         self.lock = threading.Semaphore(max_currency)
  90.  
  91.     def connect(self, host, port, dbname):
  92.         try:
  93.             conn = pymongo.Connection(host, port)
  94.         except pymongo.errors.AutoReconnect:
  95.             self.logger.warn("fail to connect mongodb @ %s:%d", host, port)
  96.  
  97.             return None
  98.  
  99.         self.logger.info("connected to mongodb @ %s:%d [%s]", conn.host, conn.port,
  100.                          ','.join(["%s: %s" % (k, v) for k, v in conn.server_info().items()]))
  101.  
  102.         self.db = conn[dbname]
  103.  
  104.         self.prepare(self.db)
  105.  
  106.         return conn
  107.  
  108.     def prepare(self, db):
  109.         if 'domains' not in db.collection_names():
  110.             self.logger.info("initialize the `domains` collection and indexes")
  111.  
  112.             db.domains.create_index([("domain", pymongo.ASCENDING)], unique=True)
  113.             db.domains.create_index([("alexa", pymongo.ASCENDING)])
  114.             db.domains.create_index([("ts", pymongo.DESCENDING)])
  115.  
  116.             db.domains.insert({
  117.                 "domain": ".",
  118.                 "ns" : [chr(ch) + '.root-servers.net' for ch in range(ord('a'), ord('m'))]
  119.             })
  120.         else:
  121.             self.logger.info("found the `domains` collection")
  122.  
  123.             db.domains.ensure_index([("domain", pymongo.ASCENDING)], unique=True)
  124.             db.domains.ensure_index([("alexa", pymongo.ASCENDING)])
  125.             db.domains.ensure_index([("ts", pymongo.DESCENDING)])
  126.  
  127.             db.domains.update({"domain": "."}, {
  128.                 "$addToSet" : {
  129.                     "ns" : {
  130.                         "$each": [chr(ch) + '.root-servers.net' for ch in range(ord('a'), ord('m'))]
  131.                     }
  132.                 }
  133.             })
  134.  
  135.     def load(self, filename):
  136.         self.logger.info("loading records from file %s", filename)
  137.  
  138.         if zipfile.is_zipfile(filename):
  139.             zip = zipfile.ZipFile(filename, 'r')
  140.             try:
  141.                 for name in zip.namelist():
  142.                     for row in csv.reader(StringIO(zip.read(name))):
  143.                         yield int(row[0]), row[1]
  144.             finally:
  145.                 zip.close()
  146.         else:
  147.             with open(filename, 'r') as f:
  148.                 for row in csv.reader(f):
  149.                     yield int(row[0]), row[1]
  150.  
  151.     def insert(self, records, update):
  152.         count = updated = 0
  153.  
  154.         domains = []
  155.  
  156.         for alexa, domain in records:
  157.             if update:
  158.                 record = self.db.domains.find_one({"domain": domain})
  159.  
  160.                 if record:
  161.                     record["alexa"] = alexa
  162.                     record["ts"] = datetime.utcnow()
  163.  
  164.                     self.db.domains.save(record)
  165.  
  166.                     updated += 1
  167.  
  168.                     if updated % 1000 == 0:
  169.                         self.logger.info("updated 1K records till %d", updated)
  170.  
  171.                     continue
  172.  
  173.             domains.append({
  174.                 "domain": domain,
  175.                 "alexa": alexa,
  176.                 "ts": datetime.utcnow()
  177.             })
  178.  
  179.             count += 1
  180.  
  181.             if len(domains) == 10000:
  182.                 self.batch_insert(count, domains)
  183.  
  184.                 domains = []
  185.  
  186.         self.batch_insert(count, domains)
  187.  
  188.     def batch_insert(self, pos, domains):
  189.         if domains:
  190.             start = time.clock()
  191.  
  192.             self.db.domains.insert(domains)
  193.  
  194.             self.logger.info("inserted 10K records till %sK in %f seconds",
  195.                              pos/1000, time.clock() - start)
  196.  
  197.     def run(self, resolver, nameservers, timeout):
  198.         self.queryLocalNameserver(resolver, nameservers, timeout)
  199.         #self.queryAuthoritativeNameserver(resolver, timeout)
  200.  
  201.     def queryLocalNameserver(self, resolver, nameservers, timeout):
  202.         cursor = self.db.domains.find({
  203.             'domain': {'$exists': True},
  204.             '$or': [
  205.                 { 'ip': {'$exists': False} },
  206.                 { 'ns': {'$exists': False} },
  207.                 { 'alias': {'$exists': False} },
  208.             ]
  209.         })
  210.  
  211.         self.logger.info("query %d domain from the local nameservers", cursor.count())
  212.  
  213.         if nameservers is None:
  214.             nameservers = DEFAULT_DNS_SERVERS
  215.  
  216.         latch = asyncdns.CountDownLatch(cursor.count()*len(nameservers))
  217.  
  218.         def onfinish(nameserver, domain, results):
  219.             self.lock.release()
  220.  
  221.             self.update(results)
  222.  
  223.             latch.countDown()
  224.  
  225.         for record in cursor:
  226.             self.lock.acquire()
  227.  
  228.             try:
  229.                 resolver.lookupAllRecords(record['domain'], expired=timeout,
  230.                                           callback=onfinish, nameservers=nameservers)
  231.             except Exception, e:
  232.                 self.logger.warn("fail to query domain: %s, %s", record['domain'], e)
  233.  
  234.         latch.await()
  235.  
  236.     def queryAuthoritativeNameservers(self, resolver, timeout):
  237.         self.logger.info("query")
  238.  
  239.         cursor = self.db.domains.find({
  240.             'domain': {'$exists': True},
  241.             'ns': {'$exists': False},
  242.         })
  243.  
  244.         self.logger.info("query %d domain from the authoritative nameservers", cursor.count())
  245.  
  246.         latch = asyncdns.CountDownLatch(cursor.count())
  247.  
  248.         def onfinish():
  249.             self.lock.release()
  250.  
  251.             latch.countDown()
  252.  
  253.         for record in cursor:
  254.             self.lock.acquire()
  255.  
  256.             try:
  257.                 resolver.lookupScene(self.queryAuthoritativeNameserver(record['domain']),
  258.                                      callback=onfinish)
  259.             except Exception, e:
  260.                 self.logger.warn("fail to query domain: %s, %s", record['domain'], e)
  261.  
  262.         latch.await()
  263.  
  264.     @asyncdns.Scene()
  265.     def queryAuthoritativeNameserver(self, domain):
  266.         qname = dns.name.from_text(domain)
  267.  
  268.         domains = ['.'.join(qname[i:]) for i in range(len(qname))]
  269.  
  270.         nameservers = None
  271.  
  272.         while len(domains) > 1:
  273.             domain = domains.pop()
  274.             domain = '.' if domain == '' else domain.strip('.')
  275.  
  276.             record = self.db.domains.find({'domain': domain}, ['ns'])
  277.  
  278.             if record is None:
  279.                 nameserver, results = yield async.scene.Query(domain, dns.rdatatype.NS,
  280.                                                               nameservers=nameservers)
  281.  
  282.                 record = {
  283.                     'domain': domain,
  284.                     'ns' : results[domain][dns.rdatatype.NS]
  285.                 }
  286.  
  287.                 record['_id'] = self.db.domains.insert(record)
  288.  
  289.             nameservers = record['ns']
  290.             nameservers = nameservers[:min(3, len(nameservers))]
  291.  
  292.         yield async.scene.Finished
  293.  
  294.     DNS_FIELDNAME_MAPPING = {
  295.         'A': 'ip',
  296.         'AAAA': 'ipv6',
  297.         'NS': 'ns',
  298.         'CNAME': 'alias',
  299.         'TXT': 'text',
  300.     }
  301.  
  302.     def update(self, results):
  303.         if isinstance(results, Exception):
  304.             self.logger.warn("fail to query domain, %s", results)
  305.             return
  306.  
  307.         for domain, records in results.items():
  308.             self.logger.info("received result for %s", domain)
  309.  
  310.             if self.db.domains.find({"domain": domain}).count() == 0:
  311.                 self.db.domains.insert({
  312.                     "domain": domain,
  313.                     "ts": datetime.utcnow()
  314.                 })
  315.  
  316.             data = {}
  317.  
  318.             for rdtype, values in records.items():
  319.                 if rdtype in ['A', 'AAAA', 'NS', 'CNAME', 'TXT']:
  320.                     data.setdefault("$addToSet", {})[self.DNS_FIELDNAME_MAPPING[rdtype]] = {
  321.                         "$each": values
  322.                     }
  323.                 elif rdtype == 'MX':
  324.                     data.setdefault("$addToSet", {})['mail'] = {
  325.                         "$each": [{
  326.                             "exchange": exchange,
  327.                             "preference": preference
  328.                         } for exchange, preference in values]
  329.                     }
  330.                 elif rdtype == 'SOA':
  331.                     for mname, rname, serial, refresh, retry, expire, minimum in values:
  332.                         data.setdefault("$set", {})["soa"] = {
  333.                             "mname": mname,
  334.                             "rname": rname,
  335.                             "serial": serial,
  336.                             "refresh": refresh,
  337.                             "retry": retry,
  338.                             "expire": expire,
  339.                             "minimum": minimum
  340.                         }
  341.                 elif rdtype == 'WKS':
  342.                     data.setdefault("$addToSet", {})['service'] = {
  343.                         "$each": [{
  344.                             "address": address,
  345.                             "protocol": protocol,
  346.                             "bitmap": bitmap
  347.                         } for address, protocol, bitmap in values]
  348.                     }
  349.                 elif rdtype == 'SRV':
  350.                     data.setdefault("$addToSet", {})['server'] = {
  351.                         "$each": [{
  352.                             "target": target,
  353.                             "port": port,
  354.                             "priority": priority,
  355.                             "weight": weight
  356.                         } for target, port, priority, weight in values]
  357.                     }
  358.                 else:
  359.                     self.logger.warn("drop domain %s unknown %s type: %s", domain, rdtype, values)
  360.  
  361.             self.db.domains.update({"domain": domain}, data)
  362.  
  363. if __name__=='__main__':
  364.     opts, args = parse_cmdline()
  365.  
  366.     logging.basicConfig(level=opts.log_level,
  367.                         format=opts.log_format,
  368.                         filename=opts.log_file,
  369.                         stream=sys.stdout)
  370.  
  371.     updater = Updater(opts.concurrency)
  372.  
  373.     if updater.connect(opts.mongo_host, opts.mongo_port, opts.db_name):
  374.         for arg in args:
  375.             if os.path.isfile(arg):
  376.                 updater.insert(updater.load(arg), opts.force_update)
  377.             else:
  378.                 print "WARN: ignore invalid argument:", arg
  379.  
  380.         wheel = asyncdns.TimeWheel()
  381.         resolver = asyncdns.Resolver(wheel)
  382.  
  383.         updater.run(resolver, opts.dns_hosts, opts.dns_timeout)
  384.     else:
  385.         print "ERROR: Fail to connect mongodb at %s:%d" % (opts.mongo_host, opts.mongo_port)
  386.         print
  387.         print "Please set the host and port with parameters, like"
  388.         print
  389.         print "     %s --mongo-host=<host> --mongo-port=<port> [options] <args>" % os.path.basename(sys.argv[0])
  390.         print
  391.         print "Or download and install mongodb from the offical site"
  392.         print
  393.         print "     http://www.mongodb.org/downloads"
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement