Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def ps(name, layer):
    """Print the shape of ``layer``, labelled with ``name``.

    Debug helper used while wiring the network together.

    Args:
        name: Label to show in the printed message.
        layer: Any object exposing a ``shape`` attribute.
    """
    print(f"Shape of {name} is {layer.shape}")
- import torch
def crop_resize(tensor, crop_size):
    """Return the centred ``crop_size`` x ``crop_size`` crop of a 4-D tensor.

    Despite the name, nothing is resized: the function only slices axes 1
    and 2 (treated as height and width) and keeps axes 0 and 3 whole.

    Each spatial axis gets its own offset, so the crop stays centred even
    when height and width differ.

    Args:
        tensor: Object with a 4-element ``shape`` that supports slice
            indexing on four axes.
        crop_size: Side length of the square crop.

    Returns:
        ``tensor[:, top:top + crop_size, left:left + crop_size, :]``.

    Raises:
        ValueError: If ``crop_size`` exceeds the height or the width; a
            negative start index would otherwise produce a silently
            wrong slice.
    """
    _, height, width, _ = tensor.shape
    if crop_size > height or crop_size > width:
        raise ValueError(
            f"crop_size {crop_size} exceeds spatial dimensions "
            f"({height}, {width})"
        )

    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return tensor[:, top:top + crop_size, left:left + crop_size, :]
# Build a U-Net-style encoder/decoder graph and print each layer's shape.
#
# NOTE(review): Input, Conv2D, MaxPooling2D, Dropout, Conv2DTranspose and
# concatenate are not imported in the visible part of this file; they look
# like Keras layer names -- confirm the import exists elsewhere.
# NOTE(review): the hard-coded crop sizes (56, 104, 200, 392) presumably
# match the up-sampled feature-map sizes when the convolutions are unpadded
# (Keras default) -- confirm against the printed shapes.

input_layer = Input((572, 572, 1))

# --- Encoder stage 1 ---
conv1 = Conv2D(64, 3, activation="relu")(input_layer)
ps("conv1", conv1)
conv1 = Conv2D(64, 3, activation="relu")(conv1)
ps("conv1", conv1)
pool1 = MaxPooling2D((2, 2), strides=2)(conv1)
#pool1 = Dropout(0.25)(pool1)
ps("pool1", pool1)

# --- Encoder stage 2 ---
conv2 = Conv2D(128, 3, activation="relu")(pool1)
ps("conv2", conv2)
conv2 = Conv2D(128, (3, 3), activation="relu")(conv2)
ps("conv2", conv2)
pool2 = MaxPooling2D((2, 2), strides=2)(conv2)
#pool2 = Dropout(0.5)(pool2)
ps("pool2", pool2)

# --- Encoder stage 3 ---
conv3 = Conv2D(256, (3, 3), activation="relu")(pool2)
ps("conv3", conv3)
conv3 = Conv2D(256, (3, 3), activation="relu")(conv3)
ps("conv3", conv3)
pool3 = MaxPooling2D((2, 2))(conv3)
ps("pool3", pool3)
#pool3 = Dropout(0.5)(pool3)

# --- Encoder stage 4 (the only stage with dropout enabled) ---
conv4 = Conv2D(512, (3, 3), activation="relu")(pool3)
ps("conv4", conv4)
conv4 = Conv2D(512, (3, 3), activation="relu")(conv4)
ps("conv4", conv4)
pool4 = MaxPooling2D((2, 2))(conv4)
pool4 = Dropout(0.5)(pool4)
ps("pool4", pool4)

# --- Middle ---
convm = Conv2D(1024, (3, 3), activation="relu")(pool4)
ps("convm", convm)
convm = Conv2D(1024, (3, 3), activation="relu")(convm)
ps("convm", convm)

# --- Decoder stage 4: upsample, crop the skip connection, concatenate ---
deconv4 = Conv2DTranspose(1024, (2, 2), strides=2)(convm)
ps("deconv4", deconv4)
conv4_cropped = crop_resize(conv4, 56)
ps("uconv4_cropped", conv4_cropped)
uconv4 = concatenate([deconv4, conv4_cropped])
ps("uconv4_concaterated", uconv4)
uconv4 = Conv2D(512, (3, 3), activation="relu")(uconv4)
ps("uconv4", uconv4)
uconv4 = Conv2D(512, (3, 3), activation="relu")(uconv4)
ps("uconv4", uconv4)

# --- Decoder stage 3 ---
deconv3 = Conv2DTranspose(512, (2, 2), strides=2)(uconv4)
ps("deconv3", deconv3)
conv3_cropped = crop_resize(conv3, 104)
uconv3 = concatenate([deconv3, conv3_cropped])
uconv3 = Conv2D(256, (3, 3), activation="relu")(uconv3)
ps("uconv3", uconv3)
uconv3 = Conv2D(256, (3, 3), activation="relu")(uconv3)
ps("uconv3", uconv3)

# --- Decoder stage 2 ---
deconv2 = Conv2DTranspose(256, (2, 2), strides=2)(uconv3)
conv2_cropped = crop_resize(conv2, 200)
uconv2 = concatenate([deconv2, conv2_cropped])
uconv2 = Conv2D(128, (3, 3), activation="relu")(uconv2)
ps("uconv2", uconv2)
uconv2 = Conv2D(128, (3, 3), activation="relu")(uconv2)
ps("uconv2", uconv2)

# --- Decoder stage 1 ---
deconv1 = Conv2DTranspose(128, (2, 2), strides=2)(uconv2)
ps("deconv1", deconv1)
conv1_cropped = crop_resize(conv1, 392)
uconv1 = concatenate([deconv1, conv1_cropped])
ps("uconv1", uconv1)
uconv1 = Conv2D(64, (3, 3), activation="relu")(uconv1)
ps("uconv1", uconv1)
uconv1 = Conv2D(64, (3, 3), activation="relu")(uconv1)
ps("uconv1", uconv1)

# --- Output head: 1x1 convolution with a sigmoid, one channel ---
output = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(uconv1)
ps("output", output)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement