Advertisement
Singasking

Untitled

Apr 21st, 2023
882
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.10 KB | None | 0 0
  1.  
  2.  
  3. def ps(name,layer):
  4.     print(f"Shape of {name} is {layer.shape}")
  5.    
  6.    
  7.  
  8. import torch
  9.  
  10. def crop_resize(tensor,crop_size):
  11.    
  12.     # Get the dimensions of the tensor
  13.     _,height, width, channels = tensor.shape
  14.    
  15.    
  16.    
  17.     # Calculate the starting and ending indices of the center crop
  18.     start_index = (height - crop_size) // 2
  19.     end_index = start_index + crop_size
  20.    
  21.     # Crop the center of the tensor
  22.     cropped = tensor[:,start_index:end_index, start_index:end_index, :]
  23.     return cropped
  24.  
  25.  
  26. input_layer = Input((572, 572, 1))
  27.  
  28. conv1 = Conv2D(64, 3, activation="relu")(input_layer)
  29. ps("conv1",conv1)
  30.  
  31. conv1 = Conv2D(64, 3, activation="relu")(conv1)
  32. ps("conv1",conv1)
  33. pool1 = MaxPooling2D((2,2),strides=2)(conv1)
  34. #pool1 = Dropout(0.25)(pool1)
  35. ps("pool1",pool1)
  36.  
  37.  
  38.  
  39. conv2 = Conv2D(128, 3, activation="relu")(pool1)
  40. ps("conv2",conv2)
  41. conv2 = Conv2D(128, (3, 3), activation="relu")(conv2)
  42. ps("conv2",conv2)
  43. pool2 = MaxPooling2D((2, 2),strides=2)(conv2)
  44. #pool2 = Dropout(0.5)(pool2)
  45. ps("pool2",pool2)
  46.  
  47. conv3 = Conv2D(256, (3, 3), activation="relu")(pool2)
  48. ps("conv3",conv3)
  49. conv3 = Conv2D(256, (3, 3), activation="relu")(conv3)
  50. ps("conv3",conv3)
  51. pool3 = MaxPooling2D((2, 2))(conv3)
  52. ps("pool3",pool3)
  53. #pool3 = Dropout(0.5)(pool3)
  54.  
  55.  
  56. conv4 = Conv2D(512, (3, 3), activation="relu")(pool3)
  57. ps("conv4",conv4)
  58. conv4 = Conv2D(512, (3, 3), activation="relu")(conv4)
  59. ps("conv4",conv4)
  60. pool4 = MaxPooling2D((2, 2))(conv4)
  61. pool4 = Dropout(0.5)(pool4)
  62. ps("pool4",pool4)
  63. # Middle
  64. convm = Conv2D(1024, (3, 3), activation="relu")(pool4)
  65. ps("convm",convm)
  66. convm = Conv2D(1024, (3, 3), activation="relu")(convm)
  67.  
  68. ps("convm",convm)
  69.        
  70.  
  71. deconv4 = Conv2DTranspose(1024, (2,2),strides=2)(convm)
  72. ps("deconv4",deconv4)
  73. conv4_cropped = crop_resize(conv4,56)
  74. ps("uconv4_cropped",conv4_cropped)
  75. uconv4 = concatenate([deconv4, conv4_cropped])
  76. ps("uconv4_concaterated",uconv4)
  77.  
  78. uconv4 = Conv2D(512, (3, 3), activation="relu")(uconv4)
  79. ps("uconv4",uconv4)
  80. uconv4 = Conv2D(512, (3, 3), activation="relu")(uconv4)
  81. ps("uconv4",uconv4)
  82.  
  83.  
  84.  
  85. deconv3 = Conv2DTranspose(512, (2,2), strides=2)(uconv4)
  86. ps("deconv3",deconv3)
  87. conv3_cropped = crop_resize(conv3,104)
  88. uconv3 = concatenate([deconv3, conv3_cropped])
  89. uconv3 = Conv2D(256,(3,3),activation="relu")(uconv3)
  90. ps("uconv3",uconv3)
  91. uconv3 = Conv2D(256,(3,3),activation="relu")(uconv3)
  92. ps("uconv3",uconv3)
  93.  
  94.  
  95. deconv2 = Conv2DTranspose(256, (2,2),strides=2)(uconv3)
  96. conv2_cropped = crop_resize(conv2,200)
  97. uconv2 = concatenate([deconv2, conv2_cropped])
  98. uconv2 = Conv2D(128,(3,3),activation="relu")(uconv2)
  99. ps("uconv2",uconv2)
  100. uconv2 = Conv2D(128,(3,3),activation="relu")(uconv2)
  101. ps("uconv2",uconv2)
  102.  
  103.  
  104. deconv1 = Conv2DTranspose(128, (2,2),strides=2)(uconv2)
  105. ps("deconv1",deconv1)
  106.  
  107. conv1_cropped = crop_resize(conv1,392)
  108. uconv1 = concatenate([deconv1, conv1_cropped])
  109. ps("uconv1",uconv1)
  110. uconv1 = Conv2D(64,(3,3),activation="relu")(uconv1)
  111. ps("uconv1",uconv1)
  112. uconv1 = Conv2D(64,(3,3),activation="relu")(uconv1)
  113. ps("uconv1",uconv1)
  114. output = Conv2D(1,(1,1),padding="same",activation="sigmoid")(uconv1)
  115. ps("output",output)
  116.  
  117.  
  118.  
  119.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement