Advertisement
MARSHAL327

Untitled

Jan 8th, 2024
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.01 KB | None | 0 0
  1. import numpy as np
  2. cimport numpy as np
  3. cimport cython
  4.  
  5. # DTYPE = np.float64
  6. # ctypedef np.float64_t DTYPE_t
  7.  
  8. ctypedef fused DTYPE_t:
  9. np.float32_t
  10. np.float64_t
  11.  
  12. def im2col_cython(np.ndarray[DTYPE_t, ndim=4] x, int field_height,
  13. int field_width, int padding, int stride):
  14. cdef int N = x.shape[0]
  15. cdef int C = x.shape[1]
  16. cdef int H = x.shape[2]
  17. cdef int W = x.shape[3]
  18.  
  19. cdef int HH = (H + 2 * padding - field_height) / int(stride) + 1
  20. cdef int WW = (W + 2 * padding - field_width) / int(stride) + 1
  21.  
  22. cdef int p = padding
  23. cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.pad(x,
  24. ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
  25.  
  26. cdef np.ndarray[DTYPE_t, ndim=2] cols = np.zeros(
  27. (C * field_height * field_width, N * HH * WW),
  28. dtype=x.dtype)
  29.  
  30. # Moving the inner loop to a C function with no bounds checking works, but does
  31. # not seem to help performance in any measurable way.
  32.  
  33. im2col_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
  34. field_height, field_width, padding, stride)
  35. return cols
  36.  
  37.  
  38. @cython.boundscheck(False)
  39. cdef int im2col_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
  40. np.ndarray[DTYPE_t, ndim=4] x_padded,
  41. int N, int C, int H, int W, int HH, int WW,
  42. int field_height, int field_width, int padding, int stride) except? -1:
  43. cdef int c, ii, jj, row, yy, xx, i, col
  44.  
  45. for c in range(C):
  46. for yy in range(HH):
  47. for xx in range(WW):
  48. for ii in range(field_height):
  49. for jj in range(field_width):
  50. row = c * field_width * field_height + ii * field_height + jj
  51. for i in range(N):
  52. col = yy * WW * N + xx * N + i
  53. cols[row, col] = x_padded[i, c, stride * yy + ii, stride * xx + jj]
  54.  
  55.  
  56.  
  57. def col2im_cython(np.ndarray[DTYPE_t, ndim=2] cols, int N, int C, int H, int W,
  58. int field_height, int field_width, int padding, int stride):
  59. cdef np.ndarray x = np.empty((N, C, H, W), dtype=cols.dtype)
  60. cdef int HH = (H + 2 * padding - field_height) / int(stride) + 1
  61. cdef int WW = (W + 2 * padding - field_width) / int(stride) + 1
  62. cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros((N, C, H + 2 * padding, W + 2 * padding),
  63. dtype=cols.dtype)
  64.  
  65. # Moving the inner loop to a C-function with no bounds checking improves
  66. # performance quite a bit for col2im.
  67. col2im_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
  68. field_height, field_width, padding, stride)
  69. if padding > 0:
  70. return x_padded[:, :, padding:-padding, padding:-padding]
  71. return x_padded
  72.  
  73.  
  74. @cython.boundscheck(False)
  75. cdef int col2im_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
  76. np.ndarray[DTYPE_t, ndim=4] x_padded,
  77. int N, int C, int H, int W, int HH, int WW,
  78. int field_height, int field_width, int padding, int stride) except? -1:
  79. cdef int c, ii, jj, row, yy, xx, i, col
  80.  
  81. for c in range(C):
  82. for ii in range(field_height):
  83. for jj in range(field_width):
  84. row = c * field_width * field_height + ii * field_height + jj
  85. for yy in range(HH):
  86. for xx in range(WW):
  87. for i in range(N):
  88. col = yy * WW * N + xx * N + i
  89. x_padded[i, c, stride * yy + ii, stride * xx + jj] += cols[row, col]
  90.  
  91.  
  92. @cython.boundscheck(False)
  93. @cython.wraparound(False)
  94. cdef col2im_6d_cython_inner(np.ndarray[DTYPE_t, ndim=6] cols,
  95. np.ndarray[DTYPE_t, ndim=4] x_padded,
  96. int N, int C, int H, int W, int HH, int WW,
  97. int out_h, int out_w, int pad, int stride):
  98.  
  99. cdef int c, hh, ww, n, h, w
  100. for n in range(N):
  101. for c in range(C):
  102. for hh in range(HH):
  103. for ww in range(WW):
  104. for h in range(out_h):
  105. for w in range(out_w):
  106. x_padded[n, c, stride * h + hh, stride * w + ww] += cols[c, hh, ww, n, h, w]
  107.  
  108.  
  109. def col2im_6d_cython(np.ndarray[DTYPE_t, ndim=6] cols, int N, int C, int H, int W,
  110. int HH, int WW, int pad, int stride):
  111. cdef np.ndarray x = np.empty((N, C, H, W), dtype=cols.dtype)
  112. cdef int out_h = (H + 2 * pad - HH) / int(stride) + 1
  113. cdef int out_w = (W + 2 * pad - WW) / int(stride) + 1
  114. cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros((N, C, H + 2 * pad, W + 2 * pad),
  115. dtype=cols.dtype)
  116.  
  117. col2im_6d_cython_inner(cols, x_padded, N, C, H, W, HH, WW, out_h, out_w, pad, stride)
  118.  
  119. if pad > 0:
  120. return x_padded[:, :, pad:-pad, pad:-pad]
  121. return x_padded
  122.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement