Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
--[[
  matrix v$(_VERSION) - matrix functions implemented with Lua tables

  Example:
    local matrix = require 'matrix'
    m1 = matrix{{8,4,1},{6,8,3}}
    m2 = matrix{{-8,1,3},{5,2,1}}
    assert(m1 + m2 == matrix{{0,5,4},{11,10,4}})

  With simple matrices this script is quite useful, though for more
  exact calculations one would probably use a program like Matlab instead.
  Matrices of size 100x100 can still be handled very well.
  The error for the determinant and the inverted matrix is around 10^-9
  with a 100x100 matrix and an element range from -100 to 100.

  Characteristics:
  - functions called via matrix.<function> should be able to handle
    any table matrix of structure t[i][j] = value
  - can handle a type of complex matrix
  - can handle symbolic matrices. (Symbolic matrices cannot be
    used with complex matrices.)
  - arithmetic functions do not change the matrix itself
    but build and return a new matrix
  - functions are intended to be light on checks
    since one gets a Lua error on incorrect use anyway
  - uses mainly Gauss-Jordan elimination
  - for Lua tables optimised determinant calculation (fast)
    but not invoking any checks for special types of matrices
  - vectors can be set up via vec1 = matrix{{ 1,2,3 }}^'T' or matrix{1,2,3}
  - vectors can be multiplied to a scalar via num = vec1^'T' * vec2
    where num will be a matrix with the result in mtx[1][1],
    or use num = vec1:scalar( vec2 ), where num is a number

  matrix function list:
    matrix.add
    matrix.columns
    matrix.concath
    matrix.concatv
    matrix.copy
    matrix.cross
    matrix.det
    matrix.div
    matrix.divnum
    matrix.dogauss
    matrix.elementstostrings
    matrix.getelement
    matrix.symbol (symbol class; provides symbol:gsub and symbol.makereplacer)
    matrix.invert
    matrix.ipairs
    matrix.latex
    matrix.len
    matrix.mul
    matrix.mulnum
    matrix:new
    matrix.normf
    matrix.normmax
    matrix.pow
    matrix.print
    matrix.random
    matrix.replace
    matrix.root
    matrix.rotl
    matrix.rotr
    matrix.round
    matrix.rows
    matrix.scalar
    matrix.setelement
    matrix.size
    matrix.solve
    matrix.sqrt
    matrix.sub
    matrix.subm
    matrix.tostring
    matrix.transpose
    matrix.type
  See code and test_matrix.lua.

  Dependencies:
    None (other than Lua 5.1 or 5.2). May be used with complex.lua.

  Build/install from a source checkout (the first command, a build
  script run as ./<script>, was truncated in this copy):
    ./<build script>
    cd tmp/*
    luarocks make

  Licensed under the same terms as Lua itself.

  Developers:
    Michael Lutz (chillcode) - original author
    David Manura
--]]
--////////////
--// matrix //
--////////////
-- Module table returned by require 'matrix'. The underscore fields are
-- informational; _VERSION is empty in this copy (the header's
-- v$(_VERSION) suggests a build step fills it in).
local matrix = {
  _TYPE = 'module',
  _NAME = 'matrix',
  _VERSION = '',
}

-- Shared metatable for every matrix this module builds. It is declared
-- (empty) up front so the functions below can refer to it; the
-- metamethods are installed near the end of the file.
local matrix_meta = {}
--/////////////////////////////
--// Get 'new' matrix object //
--/////////////////////////////

--// matrix:new ( rows [, columns [, value]] )
-- Create a matrix object: a table of row tables sharing matrix_meta.
--   * rows is a table whose first entry is a table: rows itself is adopted
--     as the matrix (no copy) and gets the matrix metatable.
--   * rows is a flat table such as {1,2,3}: it becomes a column vector,
--     i.e. an #rows x 1 matrix {{1},{2},{3}} (any length, not only 3).
--   * rows and columns are numbers: a rows x columns matrix with every
--     element set to value (default 0).
--   * rows is a number and columns is "I": the rows x rows identity matrix.
function matrix:new( rows, columns, value )
  if type( rows ) == "table" then
    if type( rows[1] ) ~= "table" then
      -- flat table: one single-element row per entry (column vector);
      -- previously only the first three entries were used
      local vec = {}
      for i = 1, #rows do
        vec[i] = { rows[i] }
      end
      return setmetatable( vec, matrix_meta )
    end
    return setmetatable( rows, matrix_meta )
  end

  local fill = value or 0
  local mtx = {}
  if columns == "I" then
    -- identity matrix of size rows x rows
    for i = 1, rows do
      local row = {}
      for j = 1, rows do
        row[j] = ( i == j ) and 1 or 0
      end
      mtx[i] = row
    end
  else
    for i = 1, rows do
      local row = {}
      for j = 1, columns do
        row[j] = fill
      end
      mtx[i] = row
    end
  end
  -- return matrix with shared metatable
  return setmetatable( mtx, matrix_meta )
end
--// matrix ( rows [, columns [, value]] )
-- Calling the module table is shorthand for matrix:new:
-- matrix( ... ) is the same as matrix:new( ... ).
-- __call receives the module table as its first argument, and that
-- argument becomes the `self` of matrix.new.
setmetatable( matrix, { __call = function( ... ) return matrix.new( ... ) end } )
-- functions are designed to be light on checks,
-- so wrong input produces ordinary Lua errors instead
-- matrix.<functions> should handle any table of structure t[i][j] = value
-- we always return a matrix with this script's metatable,
-- because that is faster than setmetatable( mtx, getmetatable( input matrix ) )
--///////////////////////////////
--// matrix 'matrix' functions //
--///////////////////////////////
--// for real, complex and symbolic matrices //--
-- note: real and complex matrices may be added, subtracted, etc.
-- real and symbolic matrices may also be added, subtracted, etc.
-- but one should avoid using symbolic matrices with complex ones
-- since it is not clear which metatable is then used
--// matrix.add ( m1, m2 )
-- Element-wise sum m1 + m2 as a new matrix; neither input is modified.
-- The result has m1's shape. m2 may be larger; its extra rows and
-- columns are ignored.
function matrix.add( m1, m2 )
  local sum = {}
  for r = 1, #m1 do
    local a, b, out_row = m1[r], m2[r], {}
    sum[r] = out_row
    for c = 1, #m1[1] do
      out_row[c] = a[c] + b[c]
    end
  end
  return setmetatable( sum, matrix_meta )
end
--// matrix.sub ( m1, m2 )
-- Element-wise difference m1 - m2 as a new matrix; neither input is
-- modified. The result has m1's shape. m2 may be larger; its extra rows
-- and columns are ignored.
function matrix.sub( m1, m2 )
  local diff = {}
  for r = 1, #m1 do
    local minuend, subtrahend, out_row = m1[r], m2[r], {}
    diff[r] = out_row
    for c = 1, #m1[1] do
      out_row[c] = minuend[c] - subtrahend[c]
    end
  end
  return setmetatable( diff, matrix_meta )
end
--// matrix.mul ( m1, m2 )
-- Matrix product m1 * m2 as a new matrix.
-- The column count of m1 must equal the row count of m2 (#m1[1] == #m2).
-- Each entry is accumulated left to right, starting from the k = 1 term.
function matrix.mul( m1, m2 )
  local product = {}
  for r = 1, #m1 do
    local row_in, row_out = m1[r], {}
    product[r] = row_out
    for c = 1, #m2[1] do
      local acc = row_in[1] * m2[1][c]
      for k = 2, #m1[1] do
        acc = acc + row_in[k] * m2[k][c]
      end
      row_out[c] = acc
    end
  end
  return setmetatable( product, matrix_meta )
end
--// matrix.div ( m1, m2 )
-- Right division m1 / m2, computed as m1 * m2^-1.
-- m2 must be square and #m1[1] == #m2. If m2 cannot be inverted, the
-- result is nil plus the rank reported by matrix.invert.
function matrix.div( m1, m2 )
  local inverse, rank = matrix.invert( m2 )
  if inverse then
    return matrix.mul( m1, inverse )
  end
  return inverse, rank -- singular
end
--// matrix.mulnum ( m1, num )
-- Multiply every element of m1 by num and return the result as a new
-- matrix. num is typically a number. Anything the elements can be
-- multiplied with through `*` also works; num is not converted here.
function matrix.mulnum( m1, num )
  local scaled = {}
  for r = 1, #m1 do
    local src, dst = m1[r], {}
    for c = 1, #m1[1] do
      dst[c] = src[c] * num
    end
    scaled[r] = dst
  end
  return setmetatable( scaled, matrix_meta )
end
--// matrix.divnum ( m1, num )
-- Divide every element of m1 by num and return the result as a new
-- matrix. num is typically a number. Anything the elements can be
-- divided by through `/` also works; num is not converted here.
function matrix.divnum( m1, num )
  local quotient = {}
  for r = 1, #m1 do
    local src, dst = m1[r], {}
    quotient[r] = dst
    for c = 1, #m1[1] do
      dst[c] = src[c] / num
    end
  end
  return setmetatable( quotient, matrix_meta )
end
--// for real and complex matrices only //--

--// matrix.pow ( m1, num )
-- Integer power of a square matrix, m1^num.
--   num == 0 : identity matrix of the same size
--   num  > 0 : repeated multiplication (num - 1 products)
--   num  < 0 : the inverse of m1 raised to -num. If m1 cannot be
--              inverted, returns nil plus the rank from matrix.invert.
-- Raises an error if num is not an integer.
function matrix.pow( m1, num )
  assert(num == math.floor(num), "exponent not an integer")
  if num == 0 then
    return matrix:new( #m1,"I" )
  end
  local base = m1
  if num < 0 then
    local inverse, rank = matrix.invert( m1 )
    if not inverse then
      return inverse, rank -- singular
    end
    base, num = inverse, -num
  end
  local result = matrix.copy( base )
  for _ = 2, num do
    result = matrix.mul( result, base )
  end
  return result
end
-- Squared magnitude of a plain number (value * value). Used as the
-- norm2 fallback when matrix elements are not tables.
local function number_norm2( value )
  local squared = value * value
  return squared
end
--// matrix.det ( m1 )
-- Determinant of a square matrix; m1 itself is not modified.
-- 1x1, 2x2 and 3x3 use closed-form expansions, so they also work for
-- symbolic matrices. Larger matrices use Gaussian elimination combined
-- with Laplace expansion, which does not support symbols. On that path
-- the elements must be numbers or tables providing .zero and .norm2.
-- Pivot choice: in each column, the non-zero element whose norm2 is
-- closest to 1 is used as the pivot.
-- Raises an error if m1 is not square.
function matrix.det( m1 )
  assert(#m1 == #m1[1], "matrix not square")
  local size = #m1
  if size == 1 then
    return m1[1][1]
  end
  if size == 2 then
    return m1[1][1]*m1[2][2] - m1[2][1]*m1[1][2]
  end
  if size == 3 then
    return ( m1[1][1]*m1[2][2]*m1[3][3] + m1[1][2]*m1[2][3]*m1[3][1] + m1[1][3]*m1[2][1]*m1[3][2]
      - m1[1][3]*m1[2][2]*m1[3][1] - m1[1][1]*m1[2][3]*m1[3][2] - m1[1][2]*m1[2][1]*m1[3][3] )
  end
  --// no symbolic matrix supported below here
  local e = m1[1][1]
  -- zero / norm2 of the element type (restored: this line had lost `e.zero`)
  local zero = type(e) == "table" and e.zero or 0
  local norm2 = type(e) == "table" and e.norm2 or number_norm2
  -- Work on a copy; each pass eliminates one column and drops one row.
  local mtx = matrix.copy( m1 )
  local det = 1
  for j = 1,#mtx[1] do
    local rows = #mtx
    -- pivot: the first non-zero element, replaced by any later non-zero
    -- element whose norm2 is closer to 1
    local subdet,xrow
    for i = 1,rows do
      local elem = mtx[i][j]
      if not subdet then
        if elem ~= zero then
          subdet,xrow = elem,i
        end
      elseif elem ~= zero and math.abs(norm2(elem)-1) < math.abs(norm2(subdet)-1) then
        subdet,xrow = elem,i
      end
    end
    if not subdet then
      -- the whole column is zero, so the determinant is 0
      -- (det * 0 presumably keeps det's element type)
      return det * 0
    end
    -- move the pivot row to the bottom; a row swap flips the sign
    if xrow ~= rows then
      mtx[rows],mtx[xrow] = mtx[xrow],mtx[rows]
      det = -det
    end
    -- eliminate column j from the other rows; column j itself is not
    -- zeroed because it is never read again
    for i = 1,rows-1 do
      if mtx[i][j] ~= zero then
        local factor = mtx[i][j]/subdet
        for n = j+1,#mtx[1] do
          mtx[i][n] = mtx[i][n] - factor * mtx[rows][n]
        end
      end
    end
    -- Laplace expansion at the pivot (bottom row, first remaining
    -- column) carries the sign (-1)^(rows+1)
    if math.fmod( rows,2 ) == 0 then
      det = -det
    end
    det = det * subdet
    table.remove( mtx )
  end
  return det
end
--// matrix.dogauss ( mtx )
-- Gauss elimination, Gauss-Jordan method
-- this function changes the matrix itself
-- returns on success: true
-- returns on failure: false, 'rank of matrix'
-- locals
-- Pivot search used by matrix.dogauss for column j, rows i..#mtx.
-- Finds the row whose element in column j has the smallest non-zero
-- |norm2(element)| and swaps it into row i (mtx is modified in place).
-- Returns true if such a row exists. Returns false if the column is
-- all zero from row i downwards.
-- The author's intent: with the smallest pivot, |factor| in dogauss is >= 1.
-- NOTE(review): textbook partial pivoting picks the LARGEST pivot for
-- numerical stability; confirm the smallest-pivot choice is intended.
local pivotOk = function( mtx,i,j,norm2 )
  local best_row, best_norm = nil, math.huge
  for row = i, #mtx do
    local n = math.abs( norm2( mtx[row][j] ) )
    if n > 0 and n < best_norm then
      best_row, best_norm = row, n
    end
  end
  if not best_row then
    return false
  end
  if best_row ~= i then
    mtx[i], mtx[best_row] = mtx[best_row], mtx[i]
  end
  return true
end
-- Copy a single element. Tables are copied through their own .copy
-- method; if that yields nothing, the table itself is returned. Any
-- other value is returned unchanged.
local function copy( x )
  if type( x ) == "table" then
    local dup = x.copy( x )
    if dup then
      return dup
    end
  end
  return x
end
--// matrix.dogauss ( mtx )
-- Gauss-Jordan elimination performed in place on mtx (rows x columns,
-- columns >= rows; e.g. a square matrix with appended columns).
-- On success, the left rows x rows part becomes the identity and the
-- remaining columns hold the transformed right-hand side.
-- Returns true on success. Returns false, j-1 when no non-zero pivot
-- exists for column j; that value is reported as the rank of the matrix.
-- Elements must be numbers or tables providing .zero, .one and .norm2.
-- note: a division-free variant (row_i*pivot - a_ij*row_j) was tried,
-- but without reducing, the numbers grow too large for big matrices.
function matrix.dogauss( mtx )
  local e = mtx[1][1]
  -- neutral elements of the element type (restored: these lines had
  -- lost `e.zero` / `e.one`)
  local zero = type(e) == "table" and e.zero or 0
  local one = type(e) == "table" and e.one or 1
  local norm2 = type(e) == "table" and e.norm2 or number_norm2
  local rows,columns = #mtx,#mtx[1]
  -- forward pass (left -> right): clear everything below each pivot
  for j = 1,rows do
    if not pivotOk( mtx,j,j,norm2 ) then
      return false,j-1
    end
    for i = j+1,rows do
      if mtx[i][j] ~= zero then
        -- row_i := row_i - factor * row_j makes mtx[i][j] zero
        local factor = mtx[i][j]/mtx[j][j]
        mtx[i][j] = copy(zero)
        for _j = j+1,columns do
          mtx[i][_j] = mtx[i][_j] - factor * mtx[j][_j]
        end
      end
    end
  end
  -- backward pass (right -> left): scale each pivot row to 1 and clear
  -- the entries above the pivot
  for j = rows,1,-1 do
    local div = mtx[j][j]
    for _j = j+1,columns do
      mtx[j][_j] = mtx[j][_j] / div
    end
    for i = j-1,1,-1 do
      if mtx[i][j] ~= zero then
        local factor = mtx[i][j]
        for _j = j+1,columns do
          mtx[i][_j] = mtx[i][_j] - factor * mtx[j][_j]
        end
        mtx[i][j] = copy(zero)
      end
    end
    mtx[j][j] = copy(one)
  end
  return true
end
--// matrix.invert ( m1 )
-- Inverse of the square matrix m1, via Gauss-Jordan elimination on
-- [m1 | I]. m1 itself is not modified.
-- Returns the inverse on success. If m1 is singular, returns nil plus
-- the rank reported by matrix.dogauss.
-- Raises an error if m1 is not square.
function matrix.invert( m1 )
  assert(#m1 == #m1[1], "matrix not square")
  local n = #m1
  local e = m1[1][1]
  -- neutral elements of the element type (restored: these lines had
  -- lost `e.zero` / `e.one`)
  local zero = type(e) == "table" and e.zero or 0
  local one = type(e) == "table" and e.one or 1
  local ident = setmetatable( {},matrix_meta )
  for i = 1,n do
    local identi = {}
    ident[i] = identi
    for j = 1,n do
      identi[j] = copy((i == j) and one or zero)
    end
  end
  -- concath already copies m1's elements, so m1 stays untouched
  local mtx = matrix.concath( m1,ident )
  local done,rank = matrix.dogauss( mtx )
  if not done then
    return nil,rank
  end
  -- integer column bounds (the old #mtx[1]/2+1 is a float on Lua 5.3+)
  return matrix.subm( mtx, 1,n+1,n,2*n )
end
--// matrix.sqrt ( m1 [,iters] )
-- calculate the square root of a matrix using the "Denman-Beavers square root iteration"
-- condition: matrix rows == matrix columns; the matrix must have an inverse and a square root
-- if called without additional arguments, the function finds the first nearest square root to
-- the input matrix; there are others, but the error between them is very small
-- if called with argument iters, the function returns the matrix after that number of iterations
-- the script returns:
-- as first argument, matrix^.5
-- as second argument, matrix^-.5
-- as third argument, the average error between (matrix^.5)^2 and the input matrix
-- you have to determine for yourself whether the result is accurate enough for you
-- local average error
-- Mean absolute element-wise difference of two equally shaped matrices:
--   sum_{i,j} |m1[i][j] - m2[i][j]| / (rows * columns)
-- Elements that are tables must provide .abs.
local function get_abs_avg( m1, m2 )
  local e = m1[1][1]
  local abs = type(e) == "table" and e.abs or math.abs
  local rows, cols = #m1, #m1[1]
  local dist = 0
  for i = 1, rows do
    for j = 1, cols do
      dist = dist + abs(m1[i][j] - m2[i][j])
    end
  end
  -- normalise by the number of entries; the old divisor rows*2 only
  -- equals the entry count for 2x2 matrices
  return dist / (rows * cols)
end
-- square root function (Denman-Beavers iteration)
-- Returns sqrt(m1), its inverse, and the average error of sqrt(m1)^2
-- against m1.
-- Without iters, it stops as soon as an iteration no longer shrinks the
-- change in the estimate, and returns the previous estimate. With iters,
-- it runs exactly that many iterations.
function matrix.sqrt( m1, iters )
  assert(#m1 == #m1[1], "matrix not square")
  local limit = iters or math.huge
  local y = matrix.copy( m1 )
  local z = matrix(#y, 'I')
  local prev_dist = math.huge
  for _ = 1, limit do
    local old_y, old_z = y, z
    -- y, z = (y + z^-1)/2, (z + y^-1)/2   (both from the old values)
    local next_y = matrix.divnum( matrix.add( old_y, matrix.invert( old_z ) ), 2 )
    local next_z = matrix.divnum( matrix.add( old_z, matrix.invert( old_y ) ), 2 )
    y, z = next_y, next_z
    local step = get_abs_avg( y, old_y )
    if limit == math.huge and step >= prev_dist then
      return old_y, old_z, get_abs_avg( matrix.mul( old_y, old_y ), m1 )
    end
    prev_dist = step
  end
  return y, z, get_abs_avg( matrix.mul( y, y ), m1 )
end
--// matrix.root ( m1, root [,iters] )
-- p-th root of a square matrix (p = root, a number), computed with a
-- coupled iteration in the spirit of matrix.sqrt:
--   x' = ((p-1)*x + y^-1) / p
--   y' = (t/p) * (y^-1)^(p-2) * t / p,   with t = (p-1)*y + x^-1
-- (the reference link for this method was lost in this copy)
-- Conditions and return values are the same as for matrix.sqrt. The
-- reported error is the average |x^p - m1|.
function matrix.root( m1, root, iters )
  assert(#m1 == #m1[1], "matrix not square")
  local limit = iters or math.huge
  local mx = matrix.copy( m1 )
  local my = matrix.mul( mx:invert(), mx:pow( root - 1 ) )
  local prev_dist = math.huge
  for _ = 1, limit do
    local old_x, old_y = mx, my
    local x_inv, y_inv = old_x:invert(), old_y:invert()
    local t = old_y:mulnum( root - 1 ):add( x_inv )
    mx = old_x:mulnum( root - 1 ):add( y_inv ):divnum( root )
    my = t:divnum( root ):mul( y_inv:pow( root - 2 ) ):mul( t ):divnum( root )
    local step = get_abs_avg( mx, old_x )
    if limit == math.huge and step >= prev_dist then
      return old_x, old_y, get_abs_avg( matrix.pow( old_x, root ), m1 )
    end
    prev_dist = step
  end
  return mx, my, get_abs_avg( matrix.pow( mx, root ), m1 )
end
--// Norm functions //--

--// matrix.normf ( mtx )
-- Frobenius norm: ||mtx||_F = sqrt( sum_{i,j} |a_ij|^2 ).
-- For non-number matrices, each element is first reduced with e:abs().
-- The final square root uses math.sqrt for a numeric sum, otherwise the
-- sum's own .sqrt (e.g. for symbols).
function matrix.normf( mtx )
  local is_number = matrix.type( mtx ) == "number"
  local total = 0
  for i = 1, #mtx do
    local row = mtx[i]
    for j = 1, #mtx[1] do
      local v = row[j]
      if not is_number then
        v = v:abs()
      end
      total = total + v^2
    end
  end
  if type( total ) == "number" then
    return math.sqrt( total )
  end
  return total.sqrt( total )
end
--// matrix.normmax ( mtx )
-- Max norm: ||mtx||_max = max |a_ij| (0 for an all-zero matrix).
-- Uses math.abs for numbers, otherwise the .abs of mtx[1][1].
-- Does not work with symbolic matrices.
function matrix.normmax( mtx )
  local abs
  if matrix.type( mtx ) == "number" then
    abs = math.abs
  else
    abs = mtx[1][1].abs
  end
  local best = 0
  for i = 1, #mtx do
    for j = 1, #mtx[1] do
      local v = abs( mtx[i][j] )
      if v > best then
        best = v
      end
    end
  end
  return best
end
--// only for number and complex type //--
-- Functions changing the matrix itself

--// matrix.round ( mtx [, idp] )
-- Round a number half-up on the grid 1/mult (mult = 10^digits).
local numround = function( num,mult )
  return math.floor( num * mult + 0.5 ) / mult
end
-- Round each array entry of t in place with the same rule; returns t.
local tround = function( t,mult )
  for idx, part in ipairs( t ) do
    t[idx] = numround( part, mult )
  end
  return t
end
-- Round every element of mtx in place to idp decimal places
-- (default 0) and return mtx. Table elements have each entry rounded.
function matrix.round( mtx, idp )
  local mult = 10^( idp or 0 )
  local fround
  if matrix.type( mtx ) == "number" then
    fround = numround
  else
    fround = tround
  end
  for i = 1, #mtx do
    local row = mtx[i]
    for j = 1, #mtx[1] do
      row[j] = fround( row[j], mult )
    end
  end
  return mtx
end
--// matrix.random( mtx [,start] [, stop] [, idp] )
-- fill matrix with random values
-- Random integer in [start, stop] divided by idp (first argument ignored).
local numfill = function( _,start,stop,idp )
  return math.random( start,stop ) / idp
end
-- Overwrite each array entry of t in place with numfill's rule; returns t.
local tfill = function( t,start,stop,idp )
  for idx in ipairs( t ) do
    t[idx] = numfill( nil, start, stop, idp )
  end
  return t
end
-- Fill mtx in place with math.random(start, stop) / idp and return mtx.
-- Defaults: start = -10, stop = 10, idp = 1. For table elements, every
-- array entry is filled.
function matrix.random( mtx,start,stop,idp )
  local lo, hi, scale = start or -10, stop or 10, idp or 1
  local ffill
  if matrix.type( mtx ) == "number" then
    ffill = numfill
  else
    ffill = tfill
  end
  for i = 1, #mtx do
    local row = mtx[i]
    for j = 1, #mtx[1] do
      row[j] = ffill( row[j], lo, hi, scale )
    end
  end
  return mtx
end
--//////////////////////////////
--// Object Utility Functions //
--//////////////////////////////
--// for all types and matrices //--

--// matrix.type ( mtx )
-- Element type of mtx, judged from mtx[1][1] only:
--   "number" when the element is not a table,
--   e:type() when it is a table with a .type method (e.g. "symbol"),
--   "tensor" for any other table.
function matrix.type( mtx )
  local e = mtx[1][1]
  if type( e ) ~= "table" then
    return "number"
  end
  if not e.type then
    return "tensor"
  end
  return e:type()
end
-- Element copiers used by copy/transpose/subm/concat/rot.
-- Non-table elements are returned unchanged.
local num_copy = function( num )
  return num
end
-- Table elements get a new table with the same metatable and a shallow
-- copy of their array part. The metatable is set before filling, so any
-- __newindex sees the assignments.
local t_copy = function( t )
  local dup = setmetatable( {}, getmetatable( t ) )
  for idx, part in ipairs( t ) do
    dup[idx] = part
  end
  return dup
end
--// matrix.copy ( m1 )
-- Copy a matrix: returns a new matrix with the same shape as m1.
-- Numeric elements are copied by value. Table elements get a shallow
-- copy of their array part (t_copy).
function matrix.copy( m1 )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local mtx = {}
  -- rows outer, columns inner (the loop bounds used to be swapped,
  -- which dropped or invented rows and columns for non-square matrices)
  for i = 1,#m1 do
    local row = {}
    for j = 1,#m1[1] do
      row[j] = docopy( m1[i][j] )
    end
    mtx[i] = row
  end
  return setmetatable( mtx, matrix_meta )
end
--// matrix.transpose ( m1 )
-- Transpose: returns a new #m1[1] x #m1 matrix with rows and columns
-- swapped. Elements are copied as in matrix.copy.
function matrix.transpose( m1 )
  local docopy = num_copy
  if matrix.type( m1 ) ~= "number" then
    docopy = t_copy
  end
  local out = {}
  for c = 1, #m1[1] do
    local col = {}
    out[c] = col
    for r = 1, #m1 do
      col[r] = docopy( m1[r][c] )
    end
  end
  return setmetatable( out, matrix_meta )
end
--// matrix.subm ( m1, i1, j1, i2, j2 )
-- Submatrix of m1 spanning rows i1..i2 and columns j1..j2 (inclusive)
-- as a new matrix, re-indexed from 1. The bounds are expected to lie
-- inside m1. Elements are copied as in matrix.copy.
function matrix.subm( m1,i1,j1,i2,j2 )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local out = {}
  for r = i1, i2 do
    local src, dst = m1[r], {}
    out[r - i1 + 1] = dst
    for c = j1, j2 do
      dst[c - j1 + 1] = docopy( src[c] )
    end
  end
  return setmetatable( out, matrix_meta )
end
--// matrix.concath( m1, m2 )
-- Horizontal concatenation [m1 m2] as a new matrix.
-- Both must have the same number of rows (#m1 == #m2). Elements of both
-- are copied using the copier chosen for m1's element type.
function matrix.concath( m1,m2 )
  assert(#m1 == #m2, "matrix size mismatch")
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local out = {}
  local left_cols = #m1[1]
  for r = 1, #m1 do
    local row = {}
    out[r] = row
    local a, b = m1[r], m2[r]
    for c = 1, left_cols do
      row[c] = docopy( a[c] )
    end
    for c = 1, #m2[1] do
      row[left_cols + c] = docopy( b[c] )
    end
  end
  return setmetatable( out, matrix_meta )
end
--// matrix.concatv ( m1, m2 )
-- Vertical concatenation: the rows of m1 followed by the rows of m2,
-- as a new matrix. Column counts must match (#m1[1] == #m2[1]).
-- Elements of both are copied using the copier chosen for m1's type.
function matrix.concatv( m1,m2 )
  assert(#m1[1] == #m2[1], "matrix size mismatch")
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  local out = {}
  -- append copies of all rows of src below the rows already in out
  local function append_rows( src )
    for r = 1, #src do
      local row = {}
      for c = 1, #src[1] do
        row[c] = docopy( src[r][c] )
      end
      out[#out + 1] = row
    end
  end
  append_rows( m1 )
  append_rows( m2 )
  return setmetatable( out, matrix_meta )
end
--// matrix.rotl ( m1 )
-- Rotate left by 90 degrees: returns a new #m1[1] x #m1 matrix in which
-- the last column of m1 becomes the first row.
function matrix.rotl( m1 )
  local nrows, ncols = #m1, #m1[1]
  local out = matrix:new( ncols, nrows )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  for r = 1, nrows do
    local src = m1[r]
    for c = 1, ncols do
      out[ncols - c + 1][r] = docopy( src[c] )
    end
  end
  return out
end
--// matrix.rotr ( m1 )
-- Rotate right by 90 degrees: returns a new #m1[1] x #m1 matrix in which
-- the last row of m1 becomes the first column.
function matrix.rotr( m1 )
  local nrows, ncols = #m1, #m1[1]
  local out = matrix:new( ncols, nrows )
  local docopy = matrix.type( m1 ) == "number" and num_copy or t_copy
  for r = 1, nrows do
    local src = m1[r]
    for c = 1, ncols do
      out[c][nrows - r + 1] = docopy( src[c] )
    end
  end
  return out
end
-- Format a tensor element (array of values) as "[v1,v2,...]". When fstr
-- is given, each value goes through string.format( fstr, v ) first.
local function tensor_tostring( t,fstr )
  local parts = t
  if fstr then
    parts = {}
    for idx, v in ipairs( t ) do
      parts[idx] = string.format( fstr,v )
    end
  end
  return "[" .. table.concat( parts, "," ) .. "]"
end
-- Format a plain element with fstr. Without fstr, the value itself is
-- returned unchanged (not converted to a string).
local function number_tostring( e,fstr )
  if fstr then
    return string.format( fstr,e )
  end
  return e
end
--// matrix.tostring ( mtx, formatstr )
-- String form of mtx: elements separated by " ", rows by " | ".
-- Tensor matrices use tensor_tostring. Table elements with a .tostring
-- use it. Everything else uses number_tostring. formatstr, if given, is
-- passed on to the element formatter.
function matrix.tostring( mtx, formatstr )
  local mtype = matrix.type( mtx )
  local e = mtx[1][1]
  local fmt
  if mtype == "tensor" then
    fmt = tensor_tostring
  elseif type( e ) == "table" and e.tostring then
    fmt = e.tostring
  else
    fmt = number_tostring
  end
  local lines = {}
  for i = 1, #mtx do
    local cells = {}
    for j = 1, #mtx[1] do
      cells[j] = fmt( mtx[i][j], formatstr )
    end
    lines[i] = table.concat( cells, " " )
  end
  return table.concat( lines, " | " )
end
--// matrix.print ( mtx [, formatstr] )
-- Print matrix.tostring( mtx, formatstr ) to standard output.
function matrix.print( ... )
  local text = matrix.tostring( ... )
  print( text )
end
--// matrix.latex ( mtx [, align] )
-- LaTeX source for mtx as a parenthesised array environment.
-- align is the column alignment letter, repeated for every column:
-- c = center (default); l = left; r = right.
-- (With \usepackage{dcolumn}, D{.}{,}{-1} aligns numbers on "." and
-- prints it as ",".)
-- Tensor elements are written as "[a,b,...]"; other elements are
-- concatenated as they are.
function matrix.latex( mtx, align )
  local col_spec = string.rep( align or "c", #mtx[1] )
  local cell = tensor_tostring
  if matrix.type( mtx ) ~= "tensor" then
    cell = number_tostring
  end
  local nrows = #mtx
  local out = "$\\left( \\begin{array}{" .. col_spec .. "}\n"
  for i = 1, nrows do
    local row = mtx[i]
    out = out .. "\t" .. cell( row[1] )
    for j = 2, #mtx[1] do
      out = out .. " & " .. cell( row[j] )
    end
    -- every row except the last ends with a LaTeX line break
    if i < nrows then
      out = out .. " \\\\\n"
    else
      out = out .. "\n"
    end
  end
  return out .. "\\end{array} \\right)$"
end
--// Functions not changing the matrix

--// matrix.rows ( mtx )
-- Number of rows of mtx.
function matrix.rows( mtx )
  local count = #mtx
  return count
end
--// matrix.columns ( mtx )
-- Number of columns of mtx, taken from its first row.
function matrix.columns( mtx )
  local first_row = mtx[1]
  return #first_row
end
--// matrix.size ( mtx )
-- Matrix dimensions as multiple return values: rows, columns.
-- Tensor matrices also return the length of the first element as a
-- third value.
function matrix.size( mtx )
  local nrows, ncols = #mtx, #mtx[1]
  if matrix.type( mtx ) ~= "tensor" then
    return nrows, ncols
  end
  return nrows, ncols, #mtx[1][1]
end
--// matrix.getelement ( mtx, i, j )
-- Element at row i, column j. Returns nothing when the row is missing
-- or the element is nil/false.
function matrix.getelement( mtx,i,j )
  local row = mtx[i]
  if not row then
    return
  end
  local value = row[j]
  if value then
    return value
  end
end
--// matrix.setelement( mtx, i, j, value )
-- Overwrite the existing element (i, j) with value. Returns 1 on
-- success. Returns nothing, without changing mtx, when the position is
-- empty (see matrix.getelement). The type of value is not checked.
function matrix.setelement( mtx,i,j,value )
  if not matrix.getelement( mtx,i,j ) then
    return
  end
  mtx[i][j] = value
  return 1
end
--// matrix.ipairs ( mtx )
-- Iterator over all (row, column) index pairs in row-major order:
--   for i, j in mtx:ipairs() do ... mtx[i][j] ... end
-- Only indices are produced, not values. The size is read once, when
-- the iterator is created. Works the same for complex matrices.
function matrix.ipairs( mtx )
  local rows, columns = #mtx, #mtx[1]
  local row, col = 1, 0
  return function()
    if col < columns then
      col = col + 1
    else
      -- past the last column: wrap to the first element of the next row
      row, col = row + 1, 1
    end
    if row <= rows then
      return row, col
    end
  end
end
--///////////////////////////////
--// matrix 'vector' functions //
--///////////////////////////////
-- a vector here is a 3x1 matrix
-- get a vector: vec = matrix{{ 1,2,3 }}^'T'

--// matrix.scalar ( m1, m2 )
-- Scalar (dot) product of two 3x1 matrices, returned as a plain value.
function matrix.scalar( m1, m2 )
  local px = m1[1][1]*m2[1][1]
  local py = m1[2][1]*m2[2][1]
  local pz = m1[3][1]*m2[3][1]
  return px + py + pz
end
--// matrix.cross ( m1, m2 )
-- Cross product of two 3x1 matrices, returned as a new 3x1 matrix.
function matrix.cross( m1, m2 )
  local a1, a2, a3 = m1[1][1], m1[2][1], m1[3][1]
  local b1, b2, b3 = m2[1][1], m2[2][1], m2[3][1]
  return setmetatable( {
    { a2*b3 - a3*b2 },
    { a3*b1 - a1*b3 },
    { a1*b2 - a2*b1 },
  }, matrix_meta )
end
--// matrix.len ( m1 )
-- Euclidean length of a 3x1 matrix: sqrt(x^2 + y^2 + z^2).
function matrix.len( m1 )
  local x, y, z = m1[1][1], m1[2][1], m1[3][1]
  return math.sqrt( x^2 + y^2 + z^2 )
end
--// matrix.replace ( mtx, func, ... )
-- New matrix whose element (i, j) is func( mtx[i][j], ... ).
-- Each row's own length (#mtx[i]) is used, so ragged rows are kept.
function matrix.replace( m1, func, ... )
  local result = {}
  for r = 1, #m1 do
    local src = m1[r]
    local dst = {}
    for c = 1, #src do
      dst[c] = func( src[c], ... )
    end
    result[r] = dst
  end
  return setmetatable( result, matrix_meta )
end
--// matrix.elementstostrings ( mtx )
-- New matrix holding the string form of each element of mtx.
-- Table elements use the .tostring of mtx[1][1] when it exists;
-- otherwise the global tostring is used.
-- IMPROVE: the names tostring and elementstostrings are easy to confuse
function matrix.elementstostrings( mtx )
  local first = mtx[1][1]
  local conv = tostring
  if type( first ) == "table" and first.tostring then
    conv = first.tostring
  end
  return matrix.replace( mtx, conv )
end
--// matrix.solve ( m1 )
-- Evaluate a symbolic matrix numerically: each element's expression
-- string is run as Lua code "return <expr>" and the result is passed
-- through tonumber (nil if it is not numeric).
-- Raises an error if m1 is not of type 'symbol' or an expression does
-- not compile.
-- SECURITY NOTE: the expressions are executed as Lua code; only use this
-- on symbols from trusted sources.
function matrix.solve( m1 )
  assert( matrix.type( m1 ) == "symbol", "matrix not of type 'symbol'" )
  -- loadstring is gone in Lua 5.3+, where load accepts source strings
  local compile = loadstring or load
  local mtx = {}
  for i = 1,#m1 do
    mtx[i] = {}
    for j = 1,#m1[1] do
      -- assert surfaces the compile error instead of "attempt to call nil"
      local chunk = assert( compile( "return "..m1[i][j][1] ) )
      mtx[i][j] = tonumber( chunk() )
    end
  end
  return setmetatable( mtx, matrix_meta )
end
--////////////////////////--
--////////////////////////--
--// MetaTable
-- matrix_meta, the local shared metatable declared at the top of the
-- file, receives its metamethods here.
-- note: the author prefers forwarding '...' over 'arg1,arg2,...'
-- where possible, for speed

-- Set add "+" behaviour: element-wise sum (matrix.add)
function matrix_meta.__add( ... )
  return matrix.add( ... )
end
-- Set subtract "-" behaviour: element-wise difference (matrix.sub)
function matrix_meta.__sub( ... )
  return matrix.sub( ... )
end
-- Set multiply "*" behaviour:
--   non-matrix * matrix and matrix * non-matrix scale the elements
--   (matrix.mulnum); matrix * matrix is the matrix product (matrix.mul).
-- An operand counts as a matrix only if its metatable is matrix_meta.
matrix_meta.__mul = function( m1,m2 )
  if getmetatable( m1 ) ~= matrix_meta then
    return matrix.mulnum( m2,m1 )
  end
  if getmetatable( m2 ) ~= matrix_meta then
    return matrix.mulnum( m1,m2 )
  end
  return matrix.mul( m1,m2 )
end
-- Set division "/" behaviour:
--   non-matrix / matrix : num * matrix^-1 (matrix.invert + matrix.mulnum)
--   matrix / non-matrix : element-wise division (matrix.divnum)
--   matrix / matrix     : m1 * m2^-1 (matrix.div)
matrix_meta.__div = function( m1,m2 )
  if getmetatable( m1 ) ~= matrix_meta then
    local inverse = matrix.invert( m2 )
    return matrix.mulnum( inverse,m1 )
  end
  if getmetatable( m2 ) ~= matrix_meta then
    return matrix.divnum( m1,m2 )
  end
  return matrix.div( m1,m2 )
end
-- Set unary minus "-" behaviour: -mtx is a new matrix equal to mtx * -1
function matrix_meta.__unm( mtx )
  return matrix.mulnum( mtx, -1 )
end
-- Set power "^" behaviour
--   mtx^n   : integer power (matrix.pow); yields nil when n < 0 and mtx
--             cannot be inverted
--   mtx^'T' : transpose
--   mtx^'*' : complex conjugate matrix (only for complex matrices).
--             NOTE(review): this calls matrix.conjugate, which is not
--             defined in this file as shown; presumably provided elsewhere.
local option = {
  -- only for complex
  ["*"] = function( m1 ) return matrix.conjugate( m1 ) end,
  -- for both
  ["T"] = function( m1 ) return matrix.transpose( m1 ) end,
}
matrix_meta.__pow = function( m1, opt )
  local handler = option[opt]
  if handler then
    -- explicit branch: with `handler(m1) or matrix.pow(m1,opt)`, a nil
    -- result would fall through to matrix.pow with a string exponent
    return handler( m1 )
  end
  return matrix.pow( m1,opt )
end
-- Set equal "==" behaviour: true when both matrices have the same
-- matrix.type, the same dimensions and pairwise equal (==) elements.
matrix_meta.__eq = function( m1, m2 )
  if matrix.type( m1 ) ~= matrix.type( m2 ) then
    return false
  end
  local rows, cols = #m1, #m1[1]
  if rows ~= #m2 or cols ~= #m2[1] then
    return false
  end
  for i = 1, rows do
    local a, b = m1[i], m2[i]
    for j = 1, cols do
      if a[j] ~= b[j] then
        return false
      end
    end
  end
  return true
end
-- Set tostring "tostring( mtx )" behaviour: delegates to matrix.tostring
function matrix_meta.__tostring( ... )
  local text = matrix.tostring( ... )
  return text
end
-- set __call behaviour: mtx( [formatstr] ) prints the matrix, the same
-- as matrix.print( mtx [, formatstr] ); it returns nothing.
function matrix_meta.__call( ... )
  matrix.print( ... )
end
--// __index handling
-- Matrix objects find their methods in a snapshot of the module table
-- taken here; entries added to `matrix` after this point (such as
-- matrix.symbol near the end of the file) are not visible as methods.
do
  local methods = {}
  for name, fn in pairs( matrix ) do
    methods[name] = fn
  end
  matrix_meta.__index = methods
end
--/////////////////////////////////
--// symbol class implementation
--/////////////////////////////////

-- access to the symbolic metatable
-- A symbol is a table {expr} whose single entry is a Lua expression
-- string; arithmetic on symbols builds new expression strings.
local symbol_meta = {}; symbol_meta.__index = symbol_meta
local symbol = symbol_meta

-- Create a symbol from any value (stored as tostring(o)).
-- (restored: the name symbol.new and the alias below were lost in this copy)
function symbol.new(o)
  return setmetatable({tostring(o)}, symbol_meta)
end
symbol.to = symbol.new

-- symbol( arg )
-- same as symbol.to( arg )
-- set __call behaviour of symbol
setmetatable( symbol_meta, { __call = function( _,s ) return symbol.to( s ) end } )
-- Converts a symbol to its expression string. When fstr is given, the
-- result is string.format( fstr, expr ).
function symbol_meta.tostring( e,fstr )
  if fstr == nil then
    -- matrix.tostring (no formatstr) and matrix.elementstostrings call
    -- this without a format string; string.format( nil, ... ) would raise
    return e[1]
  end
  return string.format( fstr,e[1] )
end
-- Returns "symbol" if the object's metatable is symbol_meta;
-- otherwise returns nothing.
function symbol_meta:type()
  if getmetatable( self ) ~= symbol_meta then
    return
  end
  return "symbol"
end
-- Performs string.gsub on the symbol's expression and returns the result
-- as a new symbol (the substitution count is discarded).
-- for use in matrix.replace
-- (restored: the symbol.to( call was lost in this copy)
function symbol_meta:gsub(from, to)
  return symbol.to( (string.gsub( self[1],from,to )) )
end
-- Creates a function that replaces single letters by something else:
-- makereplacer( "a",4,"b",7, ... )(x)
-- replaces a with 4 and b with 7 in symbol x and returns a new symbol.
-- Matching is per letter (%a), so a multi-letter name is replaced letter
-- by letter.
-- for use in matrix.replace
-- (restored: the symbol.to( call was lost in this copy)
function symbol_meta.makereplacer( ... )
  local tosub = {}
  local args = {...}
  for i = 1,#args,2 do
    tosub[args[i]] = args[i+1]
  end
  local function func( a ) return tosub[a] or a end
  return function(sym)
    return symbol.to( (string.gsub( sym[1], "%a", func )) )
  end
end
-- applies abs function to symbol: builds the expression "(expr):abs()"
-- (restored: the symbol.to( call was lost in this copy)
function symbol_meta.abs(a)
  return symbol.to("(" .. a[1] .. "):abs()")
end
-- applies sqrt function to symbol: builds the expression "(expr):sqrt()"
function symbol_meta.sqrt(a)
  return symbol.to("(" .. a[1] .. "):sqrt()")
end
-- Arithmetic on symbols returns a new symbol holding the combined
-- expression string. Operands are stringified through __concat, so a
-- plain number on either side also works.
-- *, / and ^ parenthesise both operands. + does not. For -, the right
-- operand is wrapped only when it has a top-level + or - (including a
-- leading sign): a - (b+c) stays "a-(b+c)" instead of "a-b+c", and
-- a - (-1) does not become the Lua comment "a--1".
-- (restored: the symbol.to( calls were lost in this copy)

-- true if expression string s has a + or - outside all parentheses
local function has_toplevel_sign( s )
  local depth = 0
  for pos = 1, #s do
    local ch = string.sub( s, pos, pos )
    if ch == "(" then
      depth = depth + 1
    elseif ch == ")" then
      depth = depth - 1
    elseif depth == 0 and ( ch == "+" or ch == "-" ) then
      return true
    end
  end
  return false
end

function symbol_meta.__add(a,b)
  return symbol.to(a .. "+" .. b)
end
function symbol_meta.__sub(a,b)
  local rhs = tostring(b)
  if has_toplevel_sign( rhs ) then
    rhs = "(" .. rhs .. ")"
  end
  return symbol.to(a .. "-" .. rhs)
end
function symbol_meta.__mul(a,b)
  return symbol.to("(" .. a .. ")*(" .. b .. ")")
end
function symbol_meta.__div(a,b)
  return symbol.to("(" .. a .. ")/(" .. b .. ")")
end
function symbol_meta.__pow(a,b)
  return symbol.to("(" .. a .. ")^(" .. b .. ")")
end
-- Two symbols are equal when their expression strings are equal.
function symbol_meta.__eq(a,b)
  local lhs, rhs = a[1], b[1]
  return lhs == rhs
end
-- tostring(sym) yields the raw expression string.
function symbol_meta.__tostring(a)
  local expr = a[1]
  return expr
end
-- Concatenation with a symbol on either side joins the tostring() forms.
function symbol_meta.__concat(a,b)
  local left = tostring(a)
  return left .. tostring(b)
end
-- Expose the symbol class as matrix.symbol. This assignment comes after
-- matrix_meta.__index was filled, so it is not a method of matrix objects.
matrix.symbol = symbol

-- the module value returned by require 'matrix'
return matrix

--///////////////--
--// chillcode //--
--///////////////--
Add Comment
Please, Sign In to add comment