Advertisement
exihs

Untitled

Jan 29th, 2025
10
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.02 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torchvision.transforms as transforms
  4. from PIL import Image
  5. import os
  6. import cv2
  7. import numpy as np
  8. import mediapipe as mp
  9.  
  10.  
  11. class imageProcessor():
  12. def __init__(self):
  13. self.transform = transforms.Compose([
  14. transforms.ToPILImage(),
  15. transforms.Grayscale(),
  16. transforms.Resize((64, 64)),
  17. transforms.ToTensor(),
  18. transforms.Normalize(mean=[0.5], std=[0.5])
  19. ])
  20. mp_hands = mp.solutions.hands
  21. self.hands = mp_hands.Hands(static_image_mode=True, min_detection_confidence=0.5)
  22.  
  23. def __detect_hand_mediapipe(self, img):
  24. """Detect and crop hand using MediaPipe Hands."""
  25. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  26. results = self.hands.process(img_rgb)
  27.  
  28. if results.multi_hand_landmarks:
  29. for hand_landmarks in results.multi_hand_landmarks:
  30. # Get bounding box coordinates
  31. x_min = min([lm.x for lm in hand_landmarks.landmark])
  32. y_min = min([lm.y for lm in hand_landmarks.landmark])
  33. x_max = max([lm.x for lm in hand_landmarks.landmark])
  34. y_max = max([lm.y for lm in hand_landmarks.landmark])
  35.  
  36. h, w, _ = img.shape
  37. x_min, y_min = int(x_min * w), int(y_min * h)
  38. x_max, y_max = int(x_max * w), int(y_max * h)
  39.  
  40. # Crop hand region with padding
  41. padding = 20
  42. x_min, y_min = max(0, x_min - padding), max(0, y_min - padding)
  43. x_max, y_max = min(w, x_max + padding), min(h, y_max + padding)
  44.  
  45. cropped_img = img[y_min:y_max, x_min:x_max]
  46. return cropped_img
  47.  
  48. return img # Return original image if no hand is detected
  49.  
  50. def processImage(self, img):
  51. """Preprocess images: Detect & crop hand, apply edge detection"""
  52. cropped_hand = self.__detect_hand_mediapipe(img)
  53. gray_img = cv2.cvtColor(cropped_hand, cv2.COLOR_BGR2GRAY)
  54. edges = cv2.Canny(gray_img, threshold1=50, threshold2=150)
  55. return self.transform(edges).unsqueeze(0)
  56.  
  57. class cnnModel():
  58. def __init__(self, C=1, D=64*64, classes=29, filters=64):
  59. self.model = nn.Sequential(
  60. nn.Conv2d(C, filters, 3, padding=1),
  61. nn.BatchNorm2d(filters),
  62. nn.ReLU(),
  63. nn.Conv2d(filters, filters, 3, padding=1),
  64. nn.BatchNorm2d(filters),
  65. nn.ReLU(),
  66. nn.MaxPool2d(2, stride=2),
  67.  
  68. nn.Conv2d(filters, 2*filters, 3, padding=1),
  69. nn.BatchNorm2d(2*filters),
  70. nn.ReLU(),
  71. nn.Conv2d(2*filters, 2*filters, 3, padding=1),
  72. nn.BatchNorm2d(2*filters),
  73. nn.ReLU(),
  74. nn.MaxPool2d(2, stride=2),
  75.  
  76. nn.Conv2d(2*filters, 4*filters, 3, padding=1),
  77. nn.BatchNorm2d(4*filters),
  78. nn.ReLU(),
  79. nn.Conv2d(4*filters, 4*filters, 3, padding=1),
  80. nn.BatchNorm2d(4*filters),
  81. nn.ReLU(),
  82. nn.Conv2d(4*filters, 4*filters, 3, padding=1),
  83. nn.BatchNorm2d(4*filters),
  84. nn.ReLU(),
  85. nn.Conv2d(4*filters, 4*filters, 3, padding=1),
  86. nn.BatchNorm2d(4*filters),
  87. nn.ReLU(),
  88. nn.MaxPool2d(2, stride=2),
  89.  
  90. nn.Conv2d(4*filters, 8*filters, 3, padding=1),
  91. nn.BatchNorm2d(8*filters),
  92. nn.ReLU(),
  93. nn.Conv2d(8*filters, 8*filters, 3, padding=1),
  94. nn.BatchNorm2d(8*filters),
  95. nn.ReLU(),
  96. nn.Conv2d(8*filters, 8*filters, 3, padding=1),
  97. nn.BatchNorm2d(8*filters),
  98. nn.ReLU(),
  99. nn.Conv2d(8*filters, 8*filters, 3, padding=1),
  100. nn.BatchNorm2d(8*filters),
  101. nn.ReLU(),
  102. nn.Dropout(0.2),
  103. nn.MaxPool2d(2, stride=2),
  104.  
  105. nn.Flatten(),
  106. nn.Linear(8 * filters * D // (16**2), 256),
  107. nn.ReLU(),
  108. nn.Dropout(0.5),
  109. nn.Linear(256, classes)
  110. )
  111.  
  112. self.classNames = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "del", "nothing", "space"]
  113.  
  114. def loadWeights(self, path):
  115. """Load trained weights from a specified file"""
  116. self.model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
  117.  
  118. def predictImage(self, img):
  119. """Return the class probabilities for a 64x64, 1 channel (black and white), edge detected image"""
  120. # Set the model to evaluation mode
  121. self.model.eval()
  122. with torch.no_grad():
  123. output = self.model(img)
  124. # Apply softmax to get probabilities
  125. probabilities = torch.softmax(output, dim=1)
  126.  
  127. # Convert to a list of (class_index, probability) tuples
  128. return [(self.classNames[i], round(prob.item(),5)) for i, prob in enumerate(probabilities.squeeze())]
  129.  
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement