Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- from torch.utils.data import Dataset
- import torchvision.transforms as transforms
- from torchvision.datasets import Caltech101
- import matplotlib.pyplot as plt
class GrayscaleToColorDataset(Dataset):
    """Caltech101 wrapper yielding ``(grayscale_image, color_image, label)``.

    Both views are derived from the same RGB-converted source image, so the
    grayscale input and the color target always describe identical pixels and
    the color view always has three channels (required for batching).
    """

    def __init__(self, root="./data", train=True, transform_gray=None,
                 transform_color=None, download=True):
        """
        Args:
            root (str): Directory where Caltech101 is stored / downloaded to.
            train (bool): Accepted for backward compatibility only. Caltech101
                has no predefined train/test split, so this flag is ignored and
                the full dataset is always used; split it yourself (e.g. with
                ``torch.utils.data.random_split``) if needed.
            transform_gray (callable, optional): Applied to the single-channel
                ("L" mode) PIL image. Defaults to ``transforms.ToTensor()``.
            transform_color (callable, optional): Applied to the RGB PIL image.
                Defaults to ``transforms.ToTensor()``.
            download (bool): Download the dataset into ``root`` if missing.
        """
        self.dataset = Caltech101(root=root, download=download)
        # Compare against None instead of relying on truthiness: a
        # container-like transform (e.g. an empty nn.Sequential) is falsy
        # but was still passed deliberately.
        self.transform_gray = (
            transform_gray if transform_gray is not None else transforms.ToTensor()
        )
        self.transform_color = (
            transform_color if transform_color is not None else transforms.ToTensor()
        )

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(grayscale_image, color_image, label)`` for sample ``idx``.

        The source image is first converted to RGB. ToTensor keeps a PIL
        image's channel count, so without this a single-channel source image
        would produce a 1-channel "color" tensor that cannot be stacked with
        3-channel ones in a DataLoader batch.
        """
        image, label = self.dataset[idx]
        rgb_image = image.convert("RGB")
        # Equivalent to transforms.Grayscale(num_output_channels=1) on a PIL
        # image, without building a new transform object per sample.
        grayscale_image = rgb_image.convert("L")
        return (
            self.transform_gray(grayscale_image),
            self.transform_color(rgb_image),
            label,
        )
# Shared spatial size: the grayscale input and the color target must be
# resized identically so they stay pixel-aligned.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 8

# Model input: single-channel image, normalized from ToTensor's [0, 1] range
# to [-1, 1] via (x - 0.5) / 0.5.
transform_gray = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    # GrayscaleToColorDataset already converts the image to grayscale before
    # calling this pipeline, so this step is a harmless no-op there (which is
    # why it "just works"). It is kept so the pipeline still yields one
    # channel when applied directly to a color image.
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Target: the color image in [0, 1], deliberately left unnormalized and
# without a Grayscale step.
# NOTE: ToTensor keeps the PIL image's channel count, so single-channel source
# images must be converted to RGB before reaching this pipeline; otherwise the
# DataLoader cannot stack mixed 1- and 3-channel tensors into one batch.
transform_color = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Dataset instance (downloads Caltech101 into ./data on first use).
dataset_b = GrayscaleToColorDataset(
    transform_gray=transform_gray,
    transform_color=transform_color,
)

# Each batch is a (grayscale_images, color_images, labels) triple.
dataloader_primary = torch.utils.data.DataLoader(
    dataset_b, batch_size=BATCH_SIZE, shuffle=True
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement