Advertisement
cirossmonteiro

tensor

Feb 19th, 2025 (edited)
159
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.68 KB | Source Code | 0 0
import math
import unittest
from typing import Self
  3.  
  4. # O(k)
  5. def product(dimensions: list[int]):
  6.     p = 1
  7.     for d in dimensions:
  8.         p *= d
  9.     return p
  10.  
  11. class Tensor:
  12.     index = []
  13.     def __init__(self, dimensions: list[int]):
  14.         self.dimensions = dimensions
  15.         p = product(dimensions)
  16.         self.values: list = [None for n in range(p)]
  17.         def assign(index, value):
  18.             self.values[index] = value
  19.         self.assign = assign
  20.  
  21.     @property
  22.     def order(self):
  23.         return len(self.dimensions)
  24.    
  25.     # O(k)
  26.     def compute_tensor_index(self, index: int) -> list[int]:
  27.         final = [0 for _ in self.dimensions]
  28.         p = product(self.dimensions) # this is computed too many times for the same array 'dimensions'
  29.         r = 0
  30.         for i, dimension in enumerate(self.dimensions):
  31.             p //= dimension
  32.             final[i] = (index - r) // p
  33.             r += final[i]*p
  34.         return final
  35.  
  36.     # O(k)
  37.     def compute_linear_index(self, index: list[int]) -> int:
  38.         if len(index) != len(self.dimensions):
  39.             raise Exception("Index's length and dimensions's length MUST be equal.")
  40.         p, final = 1, 0
  41.         for i, dimension in enumerate(self.dimensions[::-1]):
  42.             final += index[-i-1]*p
  43.             p *= dimension
  44.         return final
  45.  
  46.     def __getitem__(self, index):
  47.         if type(index) == list:
  48.             if len(index) == len(self.dimensions):
  49.                 pos = self.compute_linear_index(index)
  50.                 return self.values[pos]
  51.             else:
  52.                 raise IndexError("Index and dimensions MUST have the same length.")
  53.         elif len(self.index) == len(self.dimensions)-1:
  54.             index = [*self.index, index]
  55.             pos = self.compute_linear_index(index)
  56.             return self.values[pos]
  57.         else:
  58.             newt = Tensor(self.dimensions[:])
  59.             newt.values = self.values[:]
  60.             newt.index = [*self.index, index]
  61.             newt.assign = self.assign
  62.             return newt
  63.    
  64.     def __setitem__(self, index, value):
  65.         index = [*self.index, index]
  66.         if len(index) == len(self.dimensions):
  67.             pos = self.compute_linear_index(index)
  68.             self.assign(pos, value)
  69.         elif isinstance(value, Tensor):
  70.             if len(self.dimensions) == len(index) + len(value.dimensions):
  71.                 index_extended = [*index, *[0 for _ in range(len(value.dimensions))]]
  72.                 pos = self.compute_linear_index(index_extended)
  73.                 for i, x in enumerate(value.values):
  74.                     self.assign(pos + i, x)
  75.             else:
  76.                 raise IndexError("Tensor provided is not compatible with current index.")
  77.         else:
  78.             raise IndexError("Not enough indexes provided.")
  79.    
  80.     def __mul__(self, other: Self):
  81.         newt = Tensor([*self.dimensions, *other.dimensions])
  82.         p2 = product(other.dimensions)
  83.         for pos1, v1 in enumerate(self.values):
  84.             for pos2, v2 in enumerate(other.values):
  85.                 # first approach - bad!
  86.                 index1 = self.compute_tensor_index(pos1) # O(k1)
  87.                 index2 = other.compute_tensor_index(pos2) # O(k2)
  88.                 newpos = newt.compute_linear_index([*index1,*index2]) # O(k1+k2)
  89.  
  90.                 # second approach - good!
  91.                 newpos2 = pos1 * p2 + pos2
  92.                 assert newpos == newpos2
  93.                 newt.assign(newpos, v1 * v2)
  94.  
  95.         return newt
  96.    
  97.     def contraction(self, i, j):
  98.         if self.dimensions[i] == self.dimensions[j]:
  99.             dimensions = [d for k, d in enumerate(self.dimensions) if k not in [i,j]]
  100.             newt = Tensor(dimensions)
  101.             newt.values = [0 for _ in range(product(dimensions))]
  102.  
  103.             for pos, v in enumerate(self.values):
  104.                 index = self.compute_tensor_index(pos)
  105.                 if index[i] == index[j]:
  106.                     new_index = [d for k, d in enumerate(index) if k not in [i,j]]
  107.                     new_pos = newt.compute_linear_index(new_index)
  108.                     newt.values[new_pos] += v
  109.             return newt
  110.         else:
  111.             raise IndexError("Bad indices, they MUST have the same dimension.")
  112.  
  113. class TestTensorMethods(unittest.TestCase):
  114.     def test_computes(self):
  115.         linear_index = 2*4*5+3*5+4
  116.         tensor_index = [2,3,4]
  117.         dimensions = [3,4,5]
  118.         tensor = Tensor(dimensions)
  119.         self.assertEqual(
  120.             tensor.compute_linear_index(tensor_index),
  121.             linear_index
  122.         )
  123.         self.assertListEqual(
  124.             tensor.compute_tensor_index(linear_index),
  125.             tensor_index
  126.         )
  127.  
  128.     def test_element_assign(self):
  129.         dimensions = [3,4,5,6]
  130.         t1 = Tensor(dimensions)
  131.         t1.values = list(range(product(dimensions)))
  132.         self.assertEqual(t1.order, len(dimensions))
  133.         for i in range(3):
  134.             for j in range(4):
  135.                 for k in range(5):
  136.                     for l in range(6):
  137.                         v = i*4*5*6 + j*5*6 + k*6 + l
  138.                         self.assertEqual(t1[i][j][k][l], v)
  139.                         t1[i][j][k][l] = v+1
  140.                         self.assertEqual(t1[i][j][k][l], v+1)
  141.  
  142.     def test_tensor_assign(self):
  143.         d1, d2 = [3,4,5,6], [5,6]
  144.         t1, t2 = Tensor(d1), Tensor(d2)
  145.         t1.values = list(range(product(d1)))
  146.         for k in range(5):
  147.             for l in range(6):
  148.                 t2[k][l] = k*6+l
  149.         for i in range(3):
  150.             for j in range(4):
  151.                 t1[i][j] = t2
  152.                 for k in range(5):
  153.                     for l in range(6):
  154.                         self.assertEqual(t1[i][j][k][l], k*6 + l)
  155.    
  156.     def test_tensor_mult(self):
  157.         d1, d2 = [3,4], [5,6,7]
  158.         t1, t2 = Tensor(d1), Tensor(d2)
  159.         t1.values = list(range(product(d1)))
  160.         t2.values = list(range(product(d2)))
  161.         t3 = t1 * t2
  162.         self.assertEqual(t1.order, len(d1))
  163.         self.assertEqual(t2.order, len(d2))
  164.         for i in range(3):
  165.             for j in range(4):
  166.                 for k in range(5):
  167.                     for l in range(6):
  168.                         for m in range(7):
  169.                             self.assertEqual(
  170.                                 t1[i][j] * t2[k][l][m],
  171.                                 t3[i][j][k][l][m]
  172.                             )
  173.    
  174.     def test_matrix_mult(self):
  175.         d1 = d2 = [2,2]
  176.         t1 = Tensor(d1)
  177.         t2 = Tensor(d2)
  178.         t1.values = [1,2,3,4]
  179.         t2.values = [5,6,7,8]
  180.         newt = (t1*t2).contraction(1,2)
  181.         self.assertListEqual(newt.values, [19,22,43,50])
  182.  
  183.  
  184.  
  185. if __name__ == '__main__':
  186.     unittest.main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement