Advertisement
finySTAR

Untitled

Mar 11th, 2025
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.26 KB | Source Code | 0 0
  1. import torch
  2. from torch.utils.data import Dataset
  3. import torchvision.transforms as transforms
  4. from torchvision.datasets import Caltech101
  5. import matplotlib.pyplot as plt
  6.  
  7. class GrayscaleToColorDataset(Dataset):
  8.     def __init__(self, root="./data", train=True, transform_gray=None, transform_color=None):
  9.         """
  10.        Custom Dataset that returns (grayscale_image, original_color_image) pairs.
  11.  
  12.        Args:
  13.            root (str): Path to store the dataset.
  14.            train (bool): Whether to use the training split.
  15.            transform_gray (callable, optional): Transform to apply to grayscale images.
  16.            transform_color (callable, optional): Transform to apply to original color images.
  17.        """
  18.         self.dataset = Caltech101(root=root, download=True)  # Load dataset
  19.  
  20.  
  21.         self.transform_gray = transform_gray if transform_gray else transforms.ToTensor()
  22.         self.transform_color = transform_color if transform_color else transforms.ToTensor()
  23.  
  24.     def __len__(self):
  25.         return len(self.dataset)
  26.  
  27.     def __getitem__(self, idx):
  28.         """
  29.        Returns a tuple (grayscale_image, original_color_image).
  30.        """
  31.         image, label = self.dataset[idx]
  32.  
  33.         grayscale_transform = transforms.Grayscale(num_output_channels=1)
  34.         grayscale_image = grayscale_transform(image)
  35.  
  36.         grayscale_image = self.transform_gray(grayscale_image)
  37.         color_image = self.transform_color(image)
  38.  
  39.         return grayscale_image, color_image, label
  40.  
  41. transform_gray = transforms.Compose([
  42.     transforms.Resize((224, 224)),
  43.     transforms.Grayscale(num_output_channels=1), #Fel nem tudom fogni ez miért működött
  44.     transforms.ToTensor(),
  45.     transforms.Normalize(mean=[0.5], std=[0.5])
  46. ])
  47.  
  48. transform_color = transforms.Compose([
  49.     #transforms.Grayscale(num_output_channels=1), #E nélkül 3 csatornás lesz és elszáll a vizualizáció
  50.     #Brodcastolással átalakítja az 1 csat képeket 3 ra így nem száll el
  51.     transforms.Resize((224, 224)),
  52.     transforms.ToTensor()
  53. ])
  54.  
  55. # Dataset instance
  56. dataset_b = GrayscaleToColorDataset(transform_gray=transform_gray, transform_color=transform_color)
  57.  
  58. # DataLoader
  59. dataloader_primary = torch.utils.data.DataLoader(dataset_b, batch_size=8, shuffle=True)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement