Advertisement
Tatantyler

Arbitrary Precision (Unsigned) Integers (BigInt)

Apr 13th, 2013
983
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 10.79 KB | None | 0 0
  1. -- This source file is at version 1 as of the last time I've bothered to update it.
  2. -- KillaVanilla's arbitrary-precision API in Lua. Please don't steal it.
  3.  
  4. -- All arbitrary precision integers (or "bigInts" or any other capitalization of such) are unsigned. Operations where the result is less than 0 are undefined.
  5. -- BigInts are stored as tables with each digit occupying an entry. These tables store values least-significant digit first.
  6. -- For example, the number 1234 in BigInt format would be {4, 3, 2, 1}. This process is automatically done with bigInt.toBigInt().
  7.  
  8. -- Several of these functions have multiple names. For example, bigInt.mod(a,b) can also be called as bigInt.modulo(a,b), and bigInt.cmp_lt(a,b) can be called as bigInt.cmp_less_than(a,b).
  9.  
  10. -- Alternate names:
  11. -- left and right shifts: blshift() and brshift()
  12. -- sub, mul, div, mod, exp: subtract(), multiply(), divide(), modulo(), exponent()
  13. -- <the comparison functions>: cmp_<full name of comparison> (e.g. "cmp_greater_than_or_equal_to", "cmp_greater_than_equal_to", or "cmp_gteq")
  14. -- toStr: no alternate name (a global tostring() alias would shadow Lua's built-in tostring, which toStr itself calls)
  15. -- bitwise operations (AND, OR, XOR, NOT): band(), bor(), bxor(), bnot().
  16.  
  17. local function round(i) -- round a float
  18.     if i - math.floor(i) >= 0.5 then
  19.         return math.ceil(i)
  20.     end
  21.     return math.floor(i)
  22. end
  23.  
  24. local function copy(input)
  25.     if type(input) == "number" then
  26.         return toBigInt(input)
  27.     end
  28.     local t = {}
  29.     for i,v in pairs(input) do
  30.         t[i] = v
  31.     end
  32.     return t
  33. end
  34.  
  35. function removeTrailingZeroes(a)
  36.     local cpy = copy(a)
  37.     for i=#cpy, 1, -1 do
  38.         if cpy[i] ~= 0 then
  39.             break
  40.         else
  41.             cpy[i] = nil
  42.         end
  43.     end
  44.     return cpy
  45. end
  46.  
  47. cmp_lt = function(a,b) -- Less Than
  48.     local a2 = removeTrailingZeroes(a)
  49.     local b2 = removeTrailingZeroes(b)
  50.    
  51.     if #a2 > #b2 then
  52.         return false
  53.     end
  54.     if #b2 > #a2 then
  55.         return true
  56.     end
  57.    
  58.     for i=#a2, 1, -1 do
  59.         if a2[i] > b2[i] then
  60.             return false
  61.         elseif a2[i] < b2[i] then
  62.             return true
  63.         end
  64.     end
  65.     return false
  66. end
  67.  
  68. cmp_gt = function(a,b) -- Greater Than
  69.     local a2 = removeTrailingZeroes(a)
  70.     local b2 = removeTrailingZeroes(b)
  71.    
  72.     if #a2 < #b2 then
  73.         return false
  74.     end
  75.     if #b2 < #a2 then
  76.         return true
  77.     end
  78.    
  79.     for i=#a2, 1, -1 do
  80.         if a2[i] > b2[i] then
  81.             return true
  82.         elseif a2[i] < b2[i] then
  83.             return false
  84.         end
  85.     end
  86.     return false
  87. end
  88.  
  89. cmp_lteq = function(a,b) -- Less Than or EQual to
  90.     local a2 = removeTrailingZeroes(a)
  91.     local b2 = removeTrailingZeroes(b)
  92.    
  93.     if #a2 > #b2 then
  94.         return false
  95.     end
  96.     if #b2 > #a2 then
  97.         return true
  98.     end
  99.    
  100.     for i=#a2, 1, -1 do
  101.         if a2[i] > b2[i] then
  102.             return false
  103.         elseif a2[i] < b2[i] then
  104.             return true
  105.         end
  106.     end
  107.     return true
  108. end
  109.  
  110. cmp_gteq = function(a,b) --Greater Than or EQual to
  111.     local a2 = removeTrailingZeroes(a)
  112.     local b2 = removeTrailingZeroes(b)
  113.    
  114.     if #a2 < #b2 then
  115.         --print("[debug] GTEQ: a2="..toStr(a2).." b2="..toStr(b2).." #a2="..#a2.." #b2="..#b2.." #a2<#b2")
  116.         return false
  117.     end
  118.     if #b2 < #a2 then
  119.         --print("[debug] GTEQ: a2="..toStr(a2).." b2="..toStr(b2).." #a2="..#a2.." #b2="..#b2.." #b2<#a2")
  120.         return true
  121.     end
  122.    
  123.     for i=#a2, 1, -1 do
  124.         if a2[i] > b2[i] then
  125.             return true
  126.         elseif a2[i] < b2[i] then
  127.             return false
  128.         end
  129.     end
  130.     return true
  131. end
  132.  
  133. cmp_eq = function(a,b) --EQuality
  134.     local a2 = removeTrailingZeroes(a)
  135.     local b2 = removeTrailingZeroes(b)
  136.    
  137.     if #a2 < #b2 then
  138.         return false
  139.     end
  140.     if #b2 < #a2 then
  141.         return false
  142.     end
  143.    
  144.     for i=#a2, 1, -1 do
  145.         if a2[i] > b2[i] then
  146.             return false
  147.         elseif a2[i] < b2[i] then
  148.             return false
  149.         end
  150.     end
  151.     return true
  152. end
  153.  
  154. cmp_ieq = function(a,b) -- InEQuality
  155.     local a2 = removeTrailingZeroes(a)
  156.     local b2 = removeTrailingZeroes(b)
  157.    
  158.     if #a2 < #b2 then
  159.         return true
  160.     end
  161.     if #b2 < #a2 then
  162.         return true
  163.     end
  164.    
  165.     for i=#a2, 1, -1 do
  166.         if a2[i] > b2[i] then
  167.             return true
  168.         elseif a2[i] < b2[i] then
  169.             return true
  170.         end
  171.     end
  172.     return false
  173. end
  174.  
  175. local function validateBigInt(a)
  176.     if type(a) ~= "table" then
  177.         return false
  178.     end
  179.     for i=1, #a do
  180.         if type(a[i]) ~= "number" then
  181.             return false
  182.         end
  183.     end
  184.     return true
  185. end
  186.  
  187. local function add_bigInt(a, b)
  188.     local cpy = copy(a)
  189.     local carry = 0
  190.     if cmp_gt(b, a) then
  191.         return add_bigInt(b,a)
  192.     end
  193.    
  194.     for i=1, #b do
  195.         local n = a[i] or 0
  196.         local m = b[i] or 0
  197.         cpy[i] = n+m+carry
  198.         if cpy[i] > 9 then
  199.             carry = 1 -- cpy[i] cannot be greater than 18
  200.             cpy[i] = cpy[i] % 10
  201.         else
  202.             carry = 0
  203.         end
  204.     end
  205.     if carry > 0 then
  206.         local n = cpy[ #b+1 ] or 0
  207.         cpy[ #b+1 ] = n+carry
  208.     end
  209.     return removeTrailingZeroes(cpy)
  210. end
  211.  
  212. local function sub_bigInt(a,b)
  213.     local cpy = copy(a)
  214.     local borrow = 0
  215.    
  216.     for i=1, #a do
  217.         local n = a[i] or 0
  218.         local n2 = b[i] or 0
  219.         cpy[i] = n-n2-borrow
  220.         if cpy[i] < 0 then
  221.             cpy[i] = 10+cpy[i]
  222.             borrow = 1
  223.         else
  224.             borrow = 0
  225.         end
  226.     end
  227.    
  228.     return removeTrailingZeroes(cpy)
  229. end
  230.  
  231. local function mul_bigInt(a,b)
  232.     local sum = {}
  233.     local tSum = {}
  234.     local carry = 0
  235.    
  236.     for i=1, #a do
  237.         carry = 0
  238.         sum[i] = {}
  239.         for j=1, #b do
  240.             sum[i][j] = (a[i]*b[j])+carry
  241.             if sum[i][j] > 9 then
  242.                 carry = math.floor( sum[i][j]/10 )
  243.                 sum[i][j] = sum[i][j] % 10
  244.                 --sum[i][j] = ( (sum[i][j]/10) - carry )*10
  245.             else
  246.                 carry = 0
  247.             end
  248.         end
  249.         if carry > 0 then
  250.             sum[i][#b+1] = carry
  251.         end
  252.         for j=2, i do
  253.             table.insert(sum[i], 1, 0) -- table.insert(bigInt, 1, 0) is equivalent to bigInt*10. Likewise, table.remove(bigInt, 1) is equivalent to bigInt/10. table.insert(bigInt, 1, x) is eqivalent to bigInt*10+x, assuming that x is a 1-digit number
  254.         end
  255.     end
  256.    
  257.     for i=1, #a+#b do
  258.         tSum[i] = 0
  259.     end
  260.     for i=1, #sum do
  261.         tSum = add_bigInt(tSum, sum[i])
  262.     end
  263.     return removeTrailingZeroes(tSum)
  264. end
  265.  
  266. local function div_bigInt(a,b)
  267.     local bringDown = {}
  268.     local quotient = {}
  269.    
  270.     for i=#a, 1, -1 do
  271.         table.insert(bringDown, 1, a[i])
  272.         if cmp_gteq(bringDown, b) then
  273.             local add = 0
  274.             while cmp_gteq(bringDown, b) do -- while bringDown >= b do
  275.                 bringDown = sub_bigInt(bringDown, b)
  276.                 add = add+1
  277.             end
  278.             table.insert(quotient, 1, add)
  279.         else
  280.             table.insert(quotient, 1, 0)
  281.         end
  282.     end
  283.     return removeTrailingZeroes(quotient), removeTrailingZeroes(bringDown)
  284. end
  285.  
  286. local function exp_bigInt(a,b) -- exponentation by squaring. This *should* work, no promises though.
  287.     if cmp_eq(b, 1) then
  288.         return a
  289.     elseif cmp_eq(mod(b, 2), 0) then
  290.         return exp_bigInt(mul(a,a), div(b,2))
  291.     elseif cmp_eq(mod(b, 2), 1) then
  292.         return mul(a, exp_bigInt(mul(a,a), div(sub(b,1),2)))
  293.     end
  294. end
  295.  
  296. function toBinary(a) -- Convert from a arbitrary precision decimal number to an arbitrary-length table of bits (least-significant bit first)
  297.     local bitTable = {}
  298.     local cpy = copy(a)
  299.    
  300.     while true do
  301.         local quot, rem = div_bigInt(cpy, {2})
  302.         cpy = quot
  303.         rem[1] = rem[1] or 0
  304.         table.insert(bitTable, rem[1])
  305.         --print(toStr(cpy).." "..toStr(rem))
  306.         if #cpy == 0 then
  307.             break
  308.         end
  309.     end
  310.     return bitTable
  311. end
  312.  
  313. function fromBinary(a) -- Convert from an arbitrary-length table of bits (from toBinary) to an arbitrary precision decimal number
  314.     local dec = {0}
  315.     for i=#a, 1, -1 do
  316.         dec = mul_bigInt(dec, {2})
  317.         dec = add_bigInt(dec, {a[i]})
  318.     end
  319.     return dec
  320. end
  321.  
  322. local function appendBits(i, sz) -- Appends bits to make #i match sz.
  323.     local cpy = copy(i)
  324.     for j=#i, sz-1 do
  325.         table.insert(cpy, 0)
  326.     end
  327.     return cpy
  328. end
  329.  
  330. function bitwiseLeftShift(a, i)
  331.     return mul(a, exp(2, i))
  332. end
  333.  
  334. function bitwiseRightShift(a, i)
  335.     local q = div(a, exp(2, i))
  336.     return q
  337. end
  338.  
  339. function bitwiseNOT(a)
  340.     local b = toBinary(a)
  341.     for i=1, #b do
  342.         if b[i] == 0 then
  343.             b[i] = 1
  344.         else
  345.             b[i] = 0
  346.         end
  347.     end
  348.     return fromBinary(b)
  349. end
  350.  
  351. function bitwiseXOR(a, b)
  352.     local a2 = toBinary(a)
  353.     local b2 = appendBits(toBinary(b), #a2)
  354.     if #a2 > #b2 then
  355.         return bitwiseXOR(b,a)
  356.     end
  357.     for i=1, #a2 do
  358.         if a2[i] == 1 and b2[i] == 1 then
  359.             a2[i] = 0
  360.         elseif a2[i] == 0 and b2[i] == 0 then
  361.             a2[i] = 0
  362.         else
  363.             a2[i] = 1
  364.         end
  365.     end
  366.     return fromBinary(a2)
  367. end
  368.  
  369. function bitwiseOR(a, b)
  370.     local a2 = toBinary(a)
  371.     local b2 = appendBits(toBinary(b), #a2)
  372.     if #a2 > #b2 then
  373.         return bitwiseOR(b,a)
  374.     end
  375.     for i=1, #a2 do
  376.         if a2[i] == 1 or b2[i] == 1 then
  377.             a2[i] = 1
  378.         else
  379.             a2[i] = 0
  380.         end
  381.     end
  382.     return fromBinary(a2)
  383. end
  384.  
  385. function bitwiseAND(a, b)
  386.     local a2 = toBinary(a)
  387.     local b2 = appendBits(toBinary(b), #a2)
  388.     if #a2 > #b2 then
  389.         return bitwiseAND(b,a)
  390.     end
  391.     for i=1, #a2 do
  392.         if a2[i] == 1 and b2[i] == 1 then
  393.             a2[i] = 1
  394.         else
  395.             a2[i] = 0
  396.         end
  397.     end
  398.     return fromBinary(a2)
  399. end
  400.  
  401. function add(a, b)
  402.     if type(a) == "number" then
  403.         a = toBigInt(a)
  404.     end
  405.     if type(b) == "number" then
  406.         b = toBigInt(b)
  407.     end
  408.     if validateBigInt(a) and validateBigInt(b) then
  409.         return add_bigInt(a,b)
  410.     end
  411. end
  412.  
  413. function sub(a, b)
  414.     if type(a) == "number" then
  415.         a = toBigInt(a)
  416.     end
  417.     if type(b) == "number" then
  418.         b = toBigInt(b)
  419.     end
  420.     if validateBigInt(a) and validateBigInt(b) then
  421.         return sub_bigInt(a,b)
  422.     end
  423. end
  424.  
  425. function mul(a, b)
  426.     if type(a) == "number" then
  427.         a = toBigInt(a)
  428.     end
  429.     if type(b) == "number" then
  430.         b = toBigInt(b)
  431.     end
  432.     if validateBigInt(a) and validateBigInt(b) then
  433.         return mul_bigInt(a,b)
  434.     end
  435. end
  436.  
  437. function div(a, b)
  438.     if type(a) == "number" then
  439.         a = toBigInt(a)
  440.     end
  441.     if type(b) == "number" then
  442.         b = toBigInt(b)
  443.     end
  444.     if validateBigInt(a) and validateBigInt(b) then
  445.         return div_bigInt(a,b)
  446.     end
  447. end
  448.  
  449. function mod(a, b)
  450.     if type(a) == "number" then
  451.         a = toBigInt(a)
  452.     end
  453.     if type(b) == "number" then
  454.         b = toBigInt(b)
  455.     end
  456.     if validateBigInt(a) and validateBigInt(b) then
  457.         local q, r = div_bigInt(a,b)
  458.         return r
  459.     end
  460. end
  461.  
  462. function exp(a,b)
  463.     if type(a) == "number" then
  464.         a = toBigInt(a)
  465.     end
  466.     if type(b) == "number" then
  467.         b = toBigInt(b)
  468.     end
  469.     if validateBigInt(a) and validateBigInt(b) then
  470.         return exp_bigInt(a,b)
  471.     end
  472. end
  473.  
  474. function toStr(a)
  475.     local str = ""
  476.     for i=#a, 1, -1 do
  477.         str = str..string.sub(tostring(a[i]), 1, 1)
  478.     end
  479.     return str
  480. end
  481.  
  482. function toBigInt(n) -- can take either a string composed of numbers (like "1237162721379627129638372") or a small integer (such as literal 18957 or 4*197163%2)
  483.     local n2 = {}
  484.     if type(n) == "number" then
  485.         while n > 0 do
  486.              table.insert(n2,  n%10)
  487.              n = math.floor(n/10)
  488.         end
  489.     elseif type(n) == "string" then
  490.         for i=1, #n do
  491.             local digit = tonumber(string.sub(n, i,i))
  492.             if digit then
  493.                 table.insert(n2, 1, digit)
  494.             end
  495.         end
  496.     end
  497.     return n2
  498. end
  499.  
  500. -- Long names for the functions:
  501. cmp_equality = cmp_eq
  502. cmp_inequality = cmp_ieq
  503. cmp_greater_than = cmp_gt
  504. cmp_greater_than_or_equal_to = cmp_gteq
  505. cmp_greater_than_equal_to = cmp_gteq
  506. cmp_less_than = cmp_lt
  507. cmp_less_than_or_equal_to = cmp_lteq
  508. cmp_less_than_equal_to = cmp_lteq
  509. bor = bitwiseBOR
  510. bxor = bitwiseXOR
  511. band = bitwiseAND
  512. bnot = bitwiseNOT
  513. blshift = bitwiseLeftShift
  514. brshift = bitwiseRightShift
  515. subtract = sub
  516. multiply = mul
  517. divide = div
  518. modulo = mod
  519. exponent = exp
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement